Skip to content

Commit

Permalink
[easy] [XLA][HostOffloading] Separate utils used in Host Offloader
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 663034665
  • Loading branch information
tensorflower-gardener authored and copybara-github committed Aug 14, 2024
1 parent 2e8be6d commit aa23400
Show file tree
Hide file tree
Showing 6 changed files with 539 additions and 247 deletions.
53 changes: 53 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6398,6 +6398,58 @@ xla_cc_test(
],
)

# Shared helpers for the host-offloading passes: walking successor/predecessor
# instruction+shape-index chains across computation boundaries, and checks
# used while annotating/offloading buffers to host memory.
cc_library(
name = "host_offload_utils",
srcs = ["host_offload_utils.cc"],
hdrs = ["host_offload_utils.h"],
deps = [
":call_graph",
":hlo_buffer",
":host_memory_offload_annotations_hdr",
":pattern_matcher",
"//xla:literal_util",
"//xla:shape_util",
"//xla:status_macros",
"//xla:util",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:statusor",
],
)

# Unit tests for :host_offload_utils.
xla_cc_test(
name = "host_offload_utils_test",
srcs = ["host_offload_utils_test.cc"],
deps = [
":hlo_verifier",
":host_memory_offload_annotations_hdr",
":host_offload_utils",
":pattern_matcher",
":pattern_matcher_gmock",
"//xla:shape_util",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_googletest//:gtest",
"@tsl//tsl/platform:statusor",
],
)

cc_library(
name = "host_offloader",
srcs = ["host_offloader.cc"],
Expand All @@ -6410,6 +6462,7 @@ cc_library(
":hlo_pass",
":hlo_value",
":host_memory_offload_annotations_hdr",
":host_offload_utils",
":pattern_matcher",
"//xla:literal_util",
"//xla:shape_util",
Expand Down
243 changes: 243 additions & 0 deletions xla/service/host_offload_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/host_offload_utils.h"

#include <array>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/call_graph.h"
#include "xla/service/host_memory_offload_annotations.h"
#include "xla/shape_util.h"
#include "xla/util.h"

namespace xla {
namespace host_offload_utils {

namespace {

using ::xla::host_memory_offload_annotations::kMoveToDeviceCustomCallTarget;
using ::xla::host_memory_offload_annotations::kMoveToHostCustomCallTarget;

bool CustomCallReusesBuffer(const HloInstruction* custom_call,
int64_t operand_index) {
if (custom_call->custom_call_target() == kMoveToDeviceCustomCallTarget ||
custom_call->custom_call_target() == kMoveToHostCustomCallTarget) {
// Does not define a new buffer.
return true;
}
// Check the custom call's output_to_operand_aliasing.
const std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>&
aliases = custom_call->output_operand_aliasing();
for (const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>& alias :
aliases) {
int64_t alias_operand_index = alias.second.first;
if (alias_operand_index == operand_index) {
// This operand aliases with the output.
return true;
}
}
// By default, assume custom calls define new buffers.
return false;
}

} // namespace

// Returns the forward data-flow successors of `instruction_and_shape_index`,
// following uses across tuples, GTEs, calls, while loops, async-starts, and
// buffer-reusing custom calls. Returns an error if more than one operand of a
// custom call aliases the same output buffer.
absl::StatusOr<std::vector<InstructionAndShapeIndex>> GetSuccessors(
    const InstructionAndShapeIndex& instruction_and_shape_index) {
  std::vector<InstructionAndShapeIndex> result;
  HloInstruction* const instruction = instruction_and_shape_index.instruction;
  const ShapeIndex& shape_index = instruction_and_shape_index.shape_index;
  if (instruction->IsRoot()) {
    // Successor of the root is the call instruction(s).
    std::unique_ptr<CallGraph> call_graph =
        CallGraph::Build(instruction->GetModule());
    for (HloInstruction* caller :
         call_graph->GetComputationCallers(instruction->parent())) {
      result.push_back({caller, shape_index});
    }
  }
  for (HloInstruction* user : instruction->users()) {
    switch (user->opcode()) {
      case HloOpcode::kTuple: {
        // Entering a tuple pushes the operand position onto the shape index.
        for (const int64_t operand_index : user->OperandIndices(instruction)) {
          ShapeIndex user_shape_index = shape_index;
          user_shape_index.push_back(operand_index);
          result.push_back({user, std::move(user_shape_index)});
        }
        break;
      }
      case HloOpcode::kGetTupleElement: {
        // Only follow the GTE that extracts the element we are tracking;
        // leaving the tuple pops that element index off the shape index.
        ShapeIndex user_shape_index = shape_index;
        const auto tracked_index = user_shape_index.front();
        if (tracked_index == user->tuple_index()) {
          user_shape_index.pop_front();
          result.push_back({user, std::move(user_shape_index)});
        }
        break;
      }
      case HloOpcode::kCall: {
        CHECK(user->called_computations().size() == 1)
            << "Expect call to only have one called computation.";
        HloComputation* const called_computation =
            user->called_computations().front();
        // The buffer flows into the matching parameter of the callee.
        for (const int64_t operand_index : user->OperandIndices(instruction)) {
          result.push_back({called_computation->parameter_instruction(
                                operand_index),
                            shape_index});
        }
        break;
      }
      case HloOpcode::kWhile: {
        // A while operand feeds the matching parameter of both the body and
        // the condition computations.
        HloComputation* const body = user->while_body();
        HloComputation* const condition = user->while_condition();
        for (const int64_t operand_index : user->OperandIndices(instruction)) {
          result.push_back(
              {body->parameter_instruction(operand_index), shape_index});
          result.push_back(
              {condition->parameter_instruction(operand_index), shape_index});
        }
        break;
      }
      case HloOpcode::kAsyncStart: {
        CHECK(user->called_computations().size() == 1)
            << "Expect async-start to only have one called computation.";
        HloComputation* const called_computation =
            user->called_computations().front();
        for (const int64_t operand_index : user->OperandIndices(instruction)) {
          result.push_back({called_computation->parameter_instruction(
                                operand_index),
                            shape_index});
        }
        break;
      }
      case HloOpcode::kCustomCall: {
        // TODO(b/342650757): Rather than a boolean indicating whether the
        // instruction reuses the buffer, return the shape index of the output
        // that the operand aliases with.
        bool found_one = false;
        for (const int64_t operand_index : user->OperandIndices(instruction)) {
          if (!CustomCallReusesBuffer(user, operand_index)) {
            continue;
          }
          if (found_one) {
            return absl::InternalError(
                "Found multiple operands of a custom call that reuse the same "
                "output buffer.");
          }
          result.push_back({user, shape_index});
          found_one = true;
        }
        break;
      }
      default:
        // All other users pass the buffer through at the same shape index.
        result.push_back({user, shape_index});
        break;
    }
  }
  return result;
}

std::vector<InstructionAndShapeIndex> GetPredecessors(
const InstructionAndShapeIndex& instruction_and_shape_index) {
std::vector<InstructionAndShapeIndex> result;
HloInstruction* instruction = instruction_and_shape_index.instruction;
if (instruction->opcode() == HloOpcode::kGetTupleElement) {
const int64_t index = instruction->tuple_index();
auto tmp_shape_index = instruction_and_shape_index.shape_index;
tmp_shape_index.push_front(index);
result.push_back({instruction->mutable_operand(0), tmp_shape_index});
} else if (instruction->opcode() == HloOpcode::kTuple) {
CHECK(!instruction_and_shape_index.shape_index.empty())
<< "Did not store an index before encountering a tuple.";
auto tmp_shape_index = instruction_and_shape_index.shape_index;
const int64_t index = tmp_shape_index.front();
tmp_shape_index.pop_front();
result.push_back({instruction->mutable_operand(index), tmp_shape_index});
} else if (instruction->opcode() == HloOpcode::kCall) {
// Predecessor of a call is its computation's root instruction.
CHECK(instruction->called_computations().size() == 1)
<< "Expect call to only have one called computation.";
HloComputation* called_computation =
instruction->called_computations().front();
result.push_back({called_computation->root_instruction(),
instruction_and_shape_index.shape_index});
} else if (instruction->opcode() == HloOpcode::kParameter) {
std::unique_ptr<CallGraph> call_graph =
CallGraph::Build(instruction->GetModule());
auto callers = call_graph->GetComputationCallers(instruction->parent());
for (HloInstruction* caller : callers) {
result.push_back(
{caller->mutable_operand(instruction->parameter_number()),
instruction_and_shape_index.shape_index});
}
} else if (instruction->opcode() == HloOpcode::kDynamicSlice) {
result.push_back({instruction->mutable_operand(0),
instruction_and_shape_index.shape_index});
} else if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice) {
result.push_back({instruction->mutable_operand(0),
instruction_and_shape_index.shape_index});
} else if (instruction->opcode() == HloOpcode::kWhile) {
HloComputation* while_body_computation = instruction->while_body();
result.push_back({while_body_computation->root_instruction(),
instruction_and_shape_index.shape_index});
} else {
CHECK(instruction->operand_count() == 1) << absl::StreamFormat(
"Expecting instruction %s to have 1 operand, but it has %d.",
instruction->name(), instruction->operand_count());
result.push_back({instruction->mutable_operand(0),
instruction_and_shape_index.shape_index});
}
return result;
}

bool IsValidDuringPureMemoryOffload(const HloInstruction* instruction) {
static constexpr std::array allowed_opcodes = {
HloOpcode::kGetTupleElement,
HloOpcode::kBitcast,
HloOpcode::kTuple,
HloOpcode::kCall,
HloOpcode::kWhile,
HloOpcode::kParameter,
HloOpcode::kOptimizationBarrier,
HloOpcode::kAsyncStart,
HloOpcode::kAsyncDone,
HloOpcode::kCustomCall};
return absl::c_linear_search(allowed_opcodes, instruction->opcode());
}

// Two InstructionAndShapeIndex values are equal iff they reference the same
// instruction (pointer identity) and carry the same shape index.
bool operator==(const InstructionAndShapeIndex& lhs,
                const InstructionAndShapeIndex& rhs) {
  if (lhs.instruction != rhs.instruction) {
    return false;
  }
  return lhs.shape_index == rhs.shape_index;
}

std::string InstructionAndShapeIndex::ToString() const {
return absl::StrFormat("{Instr: %s, ShapeIndex: %s}", instruction->name(),
shape_index.ToString());
}

} // namespace host_offload_utils
} // namespace xla
Loading

0 comments on commit aa23400

Please sign in to comment.