Skip to content

Commit

Permalink
[XLA:CPU] Remove special case code in ElementalKernelEmitter
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 720154411
  • Loading branch information
WillFroom authored and Google-ML-Automation committed Jan 27, 2025
1 parent db8d9e1 commit e6a0517
Show file tree
Hide file tree
Showing 12 changed files with 200 additions and 103 deletions.
4 changes: 2 additions & 2 deletions xla/backends/cpu/codegen/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,11 @@ cc_library(
"//xla/hlo/ir:hlo",
"//xla/runtime:buffer_use",
"//xla/service:buffer_assignment",
"//xla/service:hlo_module_config",
"//xla/service/llvm_ir:ir_array",
"//xla/service/llvm_ir:llvm_util",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
Expand All @@ -232,7 +234,6 @@ cc_library(
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:ir_headers",
"@tsl//tsl/platform:statusor",
],
)

Expand Down Expand Up @@ -296,7 +297,6 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:JITLink",
"@llvm-project//llvm:ir_headers",
],
)
Expand Down
58 changes: 20 additions & 38 deletions xla/backends/cpu/codegen/elemental_kernel_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ limitations under the License.
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -66,18 +65,6 @@ namespace xla::cpu {

namespace {

// Builds the KernelApiIrBuilder options to use for `hlo_module`.
//
// When a module is available, both option values are read from its debug
// options. When `hlo_module` is null, falls back to `true` for the first
// option and 256 for the second.
KernelApiIrBuilder::Options KernelApiIrBuilderOptionsFromHloModuleConfig(
    const HloModule* hlo_module) {
  if (hlo_module != nullptr) {
    const auto& debug_options = hlo_module->config().debug_options();
    return KernelApiIrBuilder::Options{
        debug_options.xla_llvm_enable_invariant_load_metadata(),
        debug_options.xla_cpu_prefer_vector_width()};
  }

  // No module to consult: use the hard-coded fallback values.
  return {true, 256};
}

// Parallelization settings used by this emitter.
// NOTE(review): presumably one partition count per outermost dimension of the
// kernel's iteration space -- confirm against the code that fills this in,
// which is outside the visible hunk.
struct ParallelConfig {
std::vector<int64_t> outer_dimension_partitions;
};
Expand Down Expand Up @@ -211,46 +198,42 @@ ComputationsTransitivelyContainCustomCall(const HloInstruction* instr) {

} // namespace

ElementalKernelEmitter::ElementalKernelEmitter(const HloInstruction* instr)
: ElementalKernelEmitter(instr, nullptr, nullptr) {}

ElementalKernelEmitter::ElementalKernelEmitter(
const HloInstruction* instr, const BufferAssignment* buffer_assignment,
const TargetMachineFeatures* target_machine)
: instr_(instr),
buffer_assignment_(buffer_assignment),
target_machine_(target_machine),
context_(std::make_unique<llvm::LLVMContext>()),
kernel_api_ir_builder_(
*context_.getContext(),
KernelApiIrBuilderOptionsFromHloModuleConfig(instr_->GetModule())) {}
target_machine_(target_machine) {}

absl::StatusOr<KernelDefinition>
ElementalKernelEmitter::EmitKernelDefinition() {
VLOG(2) << "Emit elemental host kernel: " << instr_->name();

llvm::LLVMContext& ctx = *context_.getContext();
auto ctx = std::make_unique<llvm::LLVMContext>();

// A module identifier (prefix) for emitted LLVM modules.
// (Module must be prefixed with this to ensure the cpu_compiler gives correct
// name to the dumped IR file)
static constexpr absl::string_view kXlaModuleIdentifier = "__compute_module";
auto module = std::make_unique<llvm::Module>(
absl::StrCat(kXlaModuleIdentifier, "_", instr_->name(),
"_elemental_kernel_module"),
ctx);
const HloModule* hlo_module = instr_->GetModule();
if (hlo_module == nullptr) {
return Internal("HloModule is null");
}

KernelApiIrBuilder kernel_api_ir_builder(
*ctx,
KernelApiIrBuilder::Options::FromHloModuleConfig(hlo_module->config()));

std::unique_ptr<llvm::Module> llvm_module = KernelApiIrBuilder::CreateModule(
absl::StrCat(instr_->name(), "_elemental_kernel_module"), *ctx);

TF_ASSIGN_OR_RETURN(KernelApiIrBuilder::KernelPrototype kernel_prototype,
kernel_api_ir_builder_.EmitKernelPrototype(
*module, instr_, buffer_assignment_, "_kernel"));
kernel_api_ir_builder.EmitKernelPrototype(
*llvm_module, instr_, buffer_assignment_, "_kernel"));

llvm::IRBuilder<> ir_builder(ctx);
llvm::IRBuilder<> ir_builder(*ctx);
ir_builder.SetInsertPoint(
kernel_prototype.function->getEntryBlock().getTerminator());

TF_ASSIGN_OR_RETURN(
CpuElementalIrEmitter::ThreadLocalCallCallback thread_local_call_fn,
ThreadLocalCallbackFactory(ir_builder, *module));
ThreadLocalCallbackFactory(ir_builder, *llvm_module));

CpuElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
for (int64_t i = 0; i < instr_->operand_count(); ++i) {
Expand All @@ -261,12 +244,11 @@ ElementalKernelEmitter::EmitKernelDefinition() {
};
}

const HloModule* hlo_module = instr_->GetModule();
bool enable_fast_min_max =
hlo_module
? hlo_module->config().debug_options().xla_cpu_enable_fast_min_max()
: true;
CpuElementalIrEmitter elemental_ir_emitter(module.get(), &ir_builder,
CpuElementalIrEmitter elemental_ir_emitter(llvm_module.get(), &ir_builder,
std::move(thread_local_call_fn),
true, enable_fast_min_max);

Expand All @@ -277,8 +259,8 @@ ElementalKernelEmitter::EmitKernelDefinition() {
EmitElementalLoops(ir_builder, instr_, kernel_prototype,
element_generator));

auto source =
std::make_unique<LlvmIrKernelSource>(context_, std::move(module));
auto source = std::make_unique<LlvmIrKernelSource>(std::move(ctx),
std::move(llvm_module));

KernelSpec spec(kernel_prototype.function->getName(), thread_dims,
std::move(kernel_prototype.buffer_uses));
Expand Down
7 changes: 0 additions & 7 deletions xla/backends/cpu/codegen/elemental_kernel_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
#define XLA_BACKENDS_CPU_CODEGEN_ELEMENTAL_KERNEL_EMITTER_H_

#include "absl/status/statusor.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h"
Expand All @@ -34,8 +33,6 @@ namespace xla::cpu {

class ElementalKernelEmitter final : public KernelEmitter {
public:
explicit ElementalKernelEmitter(const HloInstruction* instr);

ElementalKernelEmitter(const HloInstruction* instr,
const BufferAssignment* buffer_assignment,
const TargetMachineFeatures* target_machine);
Expand All @@ -62,10 +59,6 @@ class ElementalKernelEmitter final : public KernelEmitter {

const BufferAssignment* buffer_assignment_ = nullptr;
const TargetMachineFeatures* target_machine_ = nullptr;

llvm::orc::ThreadSafeContext context_;

KernelApiIrBuilder kernel_api_ir_builder_;
};

} // namespace xla::cpu
Expand Down
54 changes: 26 additions & 28 deletions xla/backends/cpu/codegen/kernel_api_ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -48,13 +49,14 @@ limitations under the License.
#include "xla/cpu_function_runtime.h"
#include "xla/runtime/buffer_use.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/llvm_ir/ir_array.h"
#include "xla/service/llvm_ir/llvm_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "tsl/platform/statusor.h"

namespace xla::cpu {

Expand Down Expand Up @@ -86,10 +88,6 @@ class MemoryDependencyAnalyzer {

// Returns alias scope for the given buffer slice.
llvm::MDNode* GetAliasScope(BufferAllocation::Slice slice) {
if (slice.allocation() == nullptr) {
return nullptr;
}

auto it = alias_scopes_.find(slice);
return it == alias_scopes_.end() ? nullptr
: llvm::MDNode::get(context_, it->second);
Expand All @@ -107,9 +105,6 @@ class MemoryDependencyAnalyzer {
};

// Returns true when `slice` has a backing allocation and is present in
// `result_slices_`; a slice without an allocation is never reported as a
// result slice.
bool ResultContainsSlice(BufferAllocation::Slice slice) {
  const bool has_allocation = slice.allocation() != nullptr;
  return has_allocation && result_slices_.contains(slice);
}

Expand Down Expand Up @@ -245,10 +240,6 @@ absl::Status VerifyKernelParameters(
// Looks up the unique buffer slice for `instruction` at `index`.
//
// Forwards to `buffer_assignment->GetUniqueSlice()` when an assignment is
// provided; otherwise returns a default-constructed slice.
absl::StatusOr<BufferAllocation::Slice> GetUniqueSlice(
    const BufferAssignment* buffer_assignment,
    const HloInstruction* instruction, const ShapeIndex& index) {
  if (buffer_assignment != nullptr) {
    return buffer_assignment->GetUniqueSlice(instruction, index);
  }

  // No buffer assignment: hand back an empty slice instead of failing.
  return BufferAllocation::Slice{};
}

Expand Down Expand Up @@ -285,6 +276,13 @@ GetKernelResultsParameters(const HloInstruction* instruction,

} // namespace

// Builds Options from the debug options stored in `config`:
// `xla_llvm_enable_invariant_load_metadata` supplies
// `enable_invariant_load_metadata` and `xla_cpu_prefer_vector_width` supplies
// `prefer_vector_width`.
auto KernelApiIrBuilder::Options::FromHloModuleConfig(
    const HloModuleConfig& config) -> Options {
  const auto& debug_options = config.debug_options();
  return KernelApiIrBuilder::Options{
      debug_options.xla_llvm_enable_invariant_load_metadata(),
      debug_options.xla_cpu_prefer_vector_width()};
}

KernelApiIrBuilder::KernelApiIrBuilder(llvm::LLVMContext& context,
Options options)
: context_(context), options_(std::move(options)) {
Expand All @@ -304,15 +302,14 @@ auto KernelApiIrBuilder::EmitKernelPrototype(
TF_ASSIGN_OR_RETURN(std::vector<KernelParameter> results,
GetKernelResultsParameters(instr, buffer_assignment));

bool compute_alias_metadata = buffer_assignment != nullptr;
return EmitKernelPrototype(module, absl::StrCat(instr->name(), suffix),
arguments, results, compute_alias_metadata);
arguments, results);
}

auto KernelApiIrBuilder::EmitKernelPrototype(
llvm::Module& module, absl::string_view name,
absl::Span<const KernelParameter> arguments,
absl::Span<const KernelParameter> results, bool compute_alias_metadata)
absl::Span<const KernelParameter> results)
-> absl::StatusOr<KernelPrototype> {
CHECK(&module.getContext() == &context_) << "Module context mismatch";

Expand All @@ -328,13 +325,9 @@ auto KernelApiIrBuilder::EmitKernelPrototype(
<< result.slice.ToString();
}

if (compute_alias_metadata) {
TF_RETURN_IF_ERROR(VerifyKernelParameters(arguments, results));
}
TF_RETURN_IF_ERROR(VerifyKernelParameters(arguments, results));

MemoryDependencyAnalyzer memory_dependency_analyzer(
context_, name,
compute_alias_metadata ? results : absl::Span<const KernelParameter>{});
MemoryDependencyAnalyzer memory_dependency_analyzer(context_, name, results);

llvm::IRBuilder<> b(context_);

Expand Down Expand Up @@ -401,13 +394,11 @@ auto KernelApiIrBuilder::EmitKernelPrototype(
llvm::ConstantPointerNull::get(llvm::PointerType::getUnqual(context_)));

absl::InlinedVector<BufferUse, 8> buffer_uses;
if (compute_alias_metadata) {
for (const KernelParameter& argument : arguments) {
buffer_uses.push_back(BufferUse::Read(argument.slice));
}
for (const KernelParameter& result : results) {
buffer_uses.push_back(BufferUse::Write(result.slice));
}
for (const KernelParameter& argument : arguments) {
buffer_uses.push_back(BufferUse::Read(argument.slice));
}
for (const KernelParameter& result : results) {
buffer_uses.push_back(BufferUse::Write(result.slice));
}

return KernelPrototype{function,
Expand All @@ -420,6 +411,13 @@ auto KernelApiIrBuilder::EmitKernelPrototype(
std::move(buffer_uses)};
}

// Creates an LLVM module in `context` whose identifier is `name` prefixed
// with "__compute_module_".
std::unique_ptr<llvm::Module> KernelApiIrBuilder::CreateModule(
    absl::string_view name, llvm::LLVMContext& context) {
  // Prefix shared by all modules created here.
  constexpr absl::string_view kModulePrefix = "__compute_module";

  const std::string module_id = absl::StrCat(kModulePrefix, "_", name);
  return std::make_unique<llvm::Module>(module_id, context);
}

auto KernelApiIrBuilder::EmitKernelThreadDims(llvm::IRBuilderBase& builder,
llvm::Value* call_frame)
-> ThreadDims {
Expand Down
14 changes: 11 additions & 3 deletions xla/backends/cpu/codegen/kernel_api_ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#define XLA_BACKENDS_CPU_CODEGEN_KERNEL_API_IR_BUILDER_H_

#include <cstdint>
#include <memory>
#include <vector>

#include "absl/container/flat_hash_set.h"
Expand All @@ -31,6 +32,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/runtime/buffer_use.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/llvm_ir/ir_array.h"
#include "xla/shape.h"

Expand All @@ -41,6 +43,8 @@ class KernelApiIrBuilder {
// Configuration passed to the KernelApiIrBuilder constructor.
struct Options {
// Filled from the `xla_llvm_enable_invariant_load_metadata` debug option by
// FromHloModuleConfig().
bool enable_invariant_load_metadata;
// Filled from the `xla_cpu_prefer_vector_width` debug option by
// FromHloModuleConfig().
int32_t prefer_vector_width;

// Builds Options from the debug options stored in `config`.
static Options FromHloModuleConfig(const HloModuleConfig& config);
};

// Thread dimensions of the kernel invocation.
Expand Down Expand Up @@ -90,7 +94,7 @@ class KernelApiIrBuilder {
absl::InlinedVector<BufferUse, 8> buffer_uses;
};

KernelApiIrBuilder(llvm::LLVMContext& context_, Options options);
KernelApiIrBuilder(llvm::LLVMContext& context, Options options);

// Emits a kernel prototype for the given HLO instruction.
// buffer_assignment may be null, in which case we will not compute alias
Expand All @@ -102,8 +106,12 @@ class KernelApiIrBuilder {
absl::StatusOr<KernelPrototype> EmitKernelPrototype(
llvm::Module& module, absl::string_view name,
absl::Span<const KernelParameter> arguments,
absl::Span<const KernelParameter> results,
bool compute_alias_metadata = true);
absl::Span<const KernelParameter> results);

// Creates an LLVM module in `context`. The module identifier is `name` with
// an XLA-specific prefix prepended; later stages of the pipeline rely on that
// prefix being present.
static std::unique_ptr<llvm::Module> CreateModule(absl::string_view name,
llvm::LLVMContext& context);

private:
ThreadDims EmitKernelThreadDims(llvm::IRBuilderBase& builder,
Expand Down
1 change: 1 addition & 0 deletions xla/backends/cpu/testlib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ pytype_strict_library(
testonly = 1,
srcs = [
"__init__.py",
"utilities.py",
],
srcs_version = "PY3",
deps = [
Expand Down
Loading

0 comments on commit e6a0517

Please sign in to comment.