[easy] [XLA][HostOffloading] Separate utils used in Host Offloader
PiperOrigin-RevId: 663034665
1 parent 2e8be6d, commit aa23400
Showing 6 changed files with 539 additions and 247 deletions.
xla/service/host_offload_utils.cc
@@ -0,0 +1,243 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/host_offload_utils.h"

#include <array>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/call_graph.h"
#include "xla/service/host_memory_offload_annotations.h"
#include "xla/shape_util.h"
#include "xla/util.h"

namespace xla {
namespace host_offload_utils {

namespace {

using ::xla::host_memory_offload_annotations::kMoveToDeviceCustomCallTarget;
using ::xla::host_memory_offload_annotations::kMoveToHostCustomCallTarget;

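// Returns true if the custom call does not define a new buffer for the
// operand at `operand_index`, i.e. that operand aliases with (part of) the
// custom call's output.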
bool CustomCallReusesBuffer(const HloInstruction* custom_call,
                            int64_t operand_index) {
  if (custom_call->custom_call_target() == kMoveToDeviceCustomCallTarget ||
      custom_call->custom_call_target() == kMoveToHostCustomCallTarget) {
    // Does not define a new buffer.
    return true;
  }
  // Check the custom call's output_to_operand_aliasing.
  const std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>&
      aliases = custom_call->output_operand_aliasing();
  for (const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>& alias :
       aliases) {
    int64_t alias_operand_index = alias.second.first;
    if (alias_operand_index == operand_index) {
      // This operand aliases with the output.
      return true;
    }
  }
  // By default, assume custom calls define new buffers.
  return false;
}

}  // namespace

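// Returns the instruction/shape-index pairs that consume the buffer referred
// to by `instruction_and_shape_index`, following users through tuples,
// get-tuple-elements, calls, while loops, async-starts, and custom calls that
// alias the buffer. Returns an error if more than one operand of a custom
// call reuses the same output buffer.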
absl::StatusOr<std::vector<InstructionAndShapeIndex>> GetSuccessors(
    const InstructionAndShapeIndex& instruction_and_shape_index) {
  std::vector<InstructionAndShapeIndex> result;
  HloInstruction* instruction = instruction_and_shape_index.instruction;
  if (instruction->IsRoot()) {
    // Successor of the root is the call instruction(s).
    std::unique_ptr<CallGraph> call_graph =
        CallGraph::Build(instruction->GetModule());
    auto callers = call_graph->GetComputationCallers(instruction->parent());
    for (HloInstruction* caller : callers) {
      result.push_back({caller, instruction_and_shape_index.shape_index});
    }
  }
  for (HloInstruction* user : instruction->users()) {
    if (user->opcode() == HloOpcode::kTuple) {
      auto operand_indices = user->OperandIndices(instruction);
      for (const auto i : operand_indices) {
        auto tmp_shape_index = instruction_and_shape_index.shape_index;
        tmp_shape_index.push_back(i);
        result.push_back({user, std::move(tmp_shape_index)});
      }
    } else if (user->opcode() == HloOpcode::kGetTupleElement) {
      ShapeIndex tmp_shape_index = instruction_and_shape_index.shape_index;
      const auto index = tmp_shape_index.front();
      if (index == user->tuple_index()) {
        // This GTE is for the buffer we're tracking.
        tmp_shape_index.pop_front();
        result.push_back({user, std::move(tmp_shape_index)});
      }
    } else if (user->opcode() == HloOpcode::kCall) {
      auto operand_indices = user->OperandIndices(instruction);
      CHECK(user->called_computations().size() == 1)
          << "Expect call to only have one called computation.";
      for (const auto i : operand_indices) {
        HloComputation* called_computation =
            user->called_computations().front();
        HloInstruction* parameter_instruction =
            called_computation->parameter_instruction(i);
        result.push_back(
            {parameter_instruction, instruction_and_shape_index.shape_index});
      }
    } else if (user->opcode() == HloOpcode::kWhile) {
      auto operand_indices = user->OperandIndices(instruction);
      HloComputation* while_body_computation = user->while_body();
      HloComputation* while_condition_computation = user->while_condition();
      for (const auto i : operand_indices) {
        HloInstruction* parameter_instruction =
            while_body_computation->parameter_instruction(i);
        result.push_back(
            {parameter_instruction, instruction_and_shape_index.shape_index});

        HloInstruction* condition_instruction =
            while_condition_computation->parameter_instruction(i);
        result.push_back(
            {condition_instruction, instruction_and_shape_index.shape_index});
      }
    } else if (user->opcode() == HloOpcode::kAsyncStart) {
      auto operand_indices = user->OperandIndices(instruction);
      CHECK(user->called_computations().size() == 1)
          << "Expect async-start to only have one called computation.";
      for (const auto i : operand_indices) {
        HloComputation* called_computation =
            user->called_computations().front();
        HloInstruction* parameter_instruction =
            called_computation->parameter_instruction(i);
        result.push_back(
            {parameter_instruction, instruction_and_shape_index.shape_index});
      }
    } else if (user->opcode() == HloOpcode::kCustomCall) {
      const auto operand_indices = user->OperandIndices(instruction);
      // TODO(b/342650757): Rather than a boolean indicating whether the
      // instruction reuses the buffer, return the shape index of the output
      // that the operand aliases with.
      bool found_one = false;
      for (const auto i : operand_indices) {
        if (CustomCallReusesBuffer(user, i)) {
          if (found_one) {
            return absl::InternalError(
                "Found multiple operands of a custom call that reuse the same "
                "output buffer.");
          }
          result.push_back({user, instruction_and_shape_index.shape_index});
          found_one = true;
        }
      }
    } else {
      result.push_back({user, instruction_and_shape_index.shape_index});
    }
  }
  return result;
}

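// Returns the instruction/shape-index pairs that produce the buffer referred
// to by `instruction_and_shape_index`, walking backwards through
// get-tuple-elements, tuples, calls, parameters, dynamic-(update-)slices, and
// while loops. CHECK-fails on any other instruction that does not have
// exactly one operand.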
std::vector<InstructionAndShapeIndex> GetPredecessors(
    const InstructionAndShapeIndex& instruction_and_shape_index) {
  std::vector<InstructionAndShapeIndex> result;
  HloInstruction* instruction = instruction_and_shape_index.instruction;
  if (instruction->opcode() == HloOpcode::kGetTupleElement) {
    const int64_t index = instruction->tuple_index();
    auto tmp_shape_index = instruction_and_shape_index.shape_index;
    tmp_shape_index.push_front(index);
    result.push_back({instruction->mutable_operand(0), tmp_shape_index});
  } else if (instruction->opcode() == HloOpcode::kTuple) {
    CHECK(!instruction_and_shape_index.shape_index.empty())
        << "Did not store an index before encountering a tuple.";
    auto tmp_shape_index = instruction_and_shape_index.shape_index;
    const int64_t index = tmp_shape_index.front();
    tmp_shape_index.pop_front();
    result.push_back({instruction->mutable_operand(index), tmp_shape_index});
  } else if (instruction->opcode() == HloOpcode::kCall) {
    // Predecessor of a call is its computation's root instruction.
    CHECK(instruction->called_computations().size() == 1)
        << "Expect call to only have one called computation.";
    HloComputation* called_computation =
        instruction->called_computations().front();
    result.push_back({called_computation->root_instruction(),
                      instruction_and_shape_index.shape_index});
  } else if (instruction->opcode() == HloOpcode::kParameter) {
    std::unique_ptr<CallGraph> call_graph =
        CallGraph::Build(instruction->GetModule());
    auto callers = call_graph->GetComputationCallers(instruction->parent());
    for (HloInstruction* caller : callers) {
      result.push_back(
          {caller->mutable_operand(instruction->parameter_number()),
           instruction_and_shape_index.shape_index});
    }
  } else if (instruction->opcode() == HloOpcode::kDynamicSlice) {
    result.push_back({instruction->mutable_operand(0),
                      instruction_and_shape_index.shape_index});
  } else if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice) {
    result.push_back({instruction->mutable_operand(0),
                      instruction_and_shape_index.shape_index});
  } else if (instruction->opcode() == HloOpcode::kWhile) {
    HloComputation* while_body_computation = instruction->while_body();
    result.push_back({while_body_computation->root_instruction(),
                      instruction_and_shape_index.shape_index});
  } else {
    CHECK(instruction->operand_count() == 1) << absl::StreamFormat(
        "Expecting instruction %s to have 1 operand, but it has %d.",
        instruction->name(), instruction->operand_count());
    result.push_back({instruction->mutable_operand(0),
                      instruction_and_shape_index.shape_index});
  }
  return result;
}

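// Returns true if `instruction`'s opcode is in the allowlist of opcodes that
// are permitted while a buffer is purely memory-offloaded to host memory.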
bool IsValidDuringPureMemoryOffload(const HloInstruction* instruction) {
  static constexpr std::array allowed_opcodes = {
      HloOpcode::kGetTupleElement,
      HloOpcode::kBitcast,
      HloOpcode::kTuple,
      HloOpcode::kCall,
      HloOpcode::kWhile,
      HloOpcode::kParameter,
      HloOpcode::kOptimizationBarrier,
      HloOpcode::kAsyncStart,
      HloOpcode::kAsyncDone,
      HloOpcode::kCustomCall};
  return absl::c_linear_search(allowed_opcodes, instruction->opcode());
}

bool operator==(const InstructionAndShapeIndex& lhs,
                const InstructionAndShapeIndex& rhs) {
  return lhs.instruction == rhs.instruction &&
         lhs.shape_index == rhs.shape_index;
}

std::string InstructionAndShapeIndex::ToString() const {
  return absl::StrFormat("{Instr: %s, ShapeIndex: %s}", instruction->name(),
                         shape_index.ToString());
}

}  // namespace host_offload_utils
}  // namespace xla