Skip to content

Commit

Permalink
[xla:cpu] Use iterators for executing thunks sequentially
Browse files Browse the repository at this point in the history
This saves one register and a few instructions in the hot loop.

name                                     old time/op          new time/op          delta
BM_SelectAndScatterF32/128/process_time   377µs ± 4%           371µs ± 2%  -1.73%
BM_SelectAndScatterF32/256/process_time  1.55ms ± 4%          1.52ms ± 2%  -1.98%
BM_SelectAndScatterF32/512/process_time  6.64ms ± 4%          6.58ms ± 4%  -0.93%

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32
PiperOrigin-RevId: 657602607
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Jul 31, 2024
1 parent 5843411 commit fba5b1f
Show file tree
Hide file tree
Showing 31 changed files with 168 additions and 129 deletions.
15 changes: 0 additions & 15 deletions third_party/shardy/temporary.patch
Original file line number Diff line number Diff line change
@@ -1,15 +0,0 @@
diff --git i/third_party/llvm/workspace.bzl w/third_party/llvm/workspace.bzl
index 76a13a4..9345d8d 100644
--- i/third_party/llvm/workspace.bzl
+++ w/third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")

def repo(name):
"""Imports LLVM."""
- LLVM_COMMIT = "51681409aeb081c8dfe241e0d8e8c71f8bf0a4f4"
- LLVM_SHA256 = "347cc44fc5bba17b2a6ac26a253803434790a2996b77e8b6fbbeee9b8a367ec8"
+ LLVM_COMMIT = "d92a484e6f5c9063d82ca79405bb3557d88ad575"
+ LLVM_SHA256 = "0e6cce920f7344248ed747443fc16c316faf398e33f6a7f9f11f41ede861f824"

tf_http_archive(
name = name,
4 changes: 2 additions & 2 deletions third_party/shardy/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
SHARDY_COMMIT = "c87ce5b404305927c7a169b305ba0dc1c304e4ce"
SHARDY_SHA256 = "2fa411cfb31f351f2cdad997db0ccb8f9898bad3421f2a78889703bb75bd054c"
SHARDY_COMMIT = "df54e37427b0007e6527b62616ed1f66a68dda4a"
SHARDY_SHA256 = "2ebf03fd73c4578e721c539ad05b33d5fbfae6838abbb58b944e12f1eafbd9b2"

tf_http_archive(
name = "shardy",
Expand Down
15 changes: 0 additions & 15 deletions third_party/xla/third_party/shardy/temporary.patch
Original file line number Diff line number Diff line change
@@ -1,15 +0,0 @@
diff --git i/third_party/llvm/workspace.bzl w/third_party/llvm/workspace.bzl
index 76a13a4..9345d8d 100644
--- i/third_party/llvm/workspace.bzl
+++ w/third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")

def repo(name):
"""Imports LLVM."""
- LLVM_COMMIT = "51681409aeb081c8dfe241e0d8e8c71f8bf0a4f4"
- LLVM_SHA256 = "347cc44fc5bba17b2a6ac26a253803434790a2996b77e8b6fbbeee9b8a367ec8"
+ LLVM_COMMIT = "d92a484e6f5c9063d82ca79405bb3557d88ad575"
+ LLVM_SHA256 = "0e6cce920f7344248ed747443fc16c316faf398e33f6a7f9f11f41ede861f824"

tf_http_archive(
name = name,
4 changes: 2 additions & 2 deletions third_party/xla/third_party/shardy/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
SHARDY_COMMIT = "c87ce5b404305927c7a169b305ba0dc1c304e4ce"
SHARDY_SHA256 = "2fa411cfb31f351f2cdad997db0ccb8f9898bad3421f2a78889703bb75bd054c"
SHARDY_COMMIT = "df54e37427b0007e6527b62616ed1f66a68dda4a"
SHARDY_SHA256 = "2ebf03fd73c4578e721c539ad05b33d5fbfae6838abbb58b944e12f1eafbd9b2"

tf_http_archive(
name = "shardy",
Expand Down
24 changes: 12 additions & 12 deletions third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,18 +186,18 @@ void AddGraphTraceActivityEvent(CuptiEventCollectorDelegate &collector,
AnnotationMap::AnnotationInfo info = collector.annotation_map.LookUp(
graph_trace->deviceId, graph_trace->correlationId);
collector.receive(CuptiTracerEvent{
.type = CuptiTracerEventType::CudaGraph,
.source = CuptiTracerEventSource::Activity,
.name = absl::StrCat("CudaGraphExec:", graph_trace->graphId),
.annotation = info.annotation,
.nvtx_range = info.nvtx_range,
.start_time_ns = graph_trace->start,
.end_time_ns = graph_trace->end,
.device_id = graph_trace->deviceId,
.correlation_id = graph_trace->correlationId,
.context_id = graph_trace->contextId,
.stream_id = graph_trace->streamId,
.graph_id = graph_trace->graphId,
/* .type = */ CuptiTracerEventType::CudaGraph,
/* .source = */ CuptiTracerEventSource::Activity,
/* .name = */ absl::StrCat("CudaGraphExec:", graph_trace->graphId),
/* .annotation = */ info.annotation,
/* .nvtx_range = */ info.nvtx_range,
/* .start_time_ns = */ graph_trace->start,
/* .end_time_ns = */ graph_trace->end,
/* .device_id = */ graph_trace->deviceId,
/* .correlation_id = */ graph_trace->correlationId,
/* .context_id = */ graph_trace->contextId,
/* .stream_id = */ graph_trace->streamId,
/* .graph_id = */ graph_trace->graphId,
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ struct MemcpyDetails {
int8_t dst_mem_kind;

// ID of the hardware channel on which this operation ran.
uint32_t channel_id = -1;
uint32_t channel_id = static_cast<uint32_t>(-1);
// CUpti_ChannelType of the channel above.
int8_t channel_type = 0; // CUPTI_CHANNEL_TYPE_INVALID
};
Expand Down
2 changes: 2 additions & 0 deletions third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,9 @@ std::optional<DynamicOrStaticInteger> EvaluateWhileLoopParamInitValue(

namespace internal {

#if !defined(_MSC_VER)
constexpr absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl";
#endif

std::optional<EvalErrorDetail> ParseEvalErrorDetail(const absl::Status& error) {
auto error_detail = error.GetPayload(kEvalErrorDetailUrl);
Expand Down
4 changes: 4 additions & 0 deletions third_party/xla/xla/hlo/evaluator/hlo_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,11 @@ enum class EvalErrorDetail : uint32_t {
kDynamicValueDependence = 0,
};

#if defined(_MSC_VER)
extern const absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl";
#else
extern const absl::string_view kEvalErrorDetailUrl;
#endif

std::optional<EvalErrorDetail> ParseEvalErrorDetail(const absl::Status& error);

Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2129,7 +2129,7 @@ PJRT_Error* PJRT_Layouts_MemoryLayout_Serialize(
PJRT_Layouts_MemoryLayout_Serialize_Args_STRUCT_SIZE, args->struct_size));

PJRT_Layouts_SerializedLayout* s_layout = new PJRT_Layouts_SerializedLayout{
.serialized = args->layout->layout->Serialize()};
/* .serialized = */ args->layout->layout->Serialize()};
args->serialized_layout = s_layout;
args->serialized_bytes = s_layout->serialized.data();
args->serialized_bytes_size = s_layout->serialized.size();
Expand Down
14 changes: 8 additions & 6 deletions third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,15 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options,
#endif
}

STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, {
PjRtRegisterCompiler(
#if TENSORFLOW_USE_ROCM
RocmName(),
STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, {
PjRtRegisterCompiler(RocmName(),
std::make_unique<StreamExecutorGpuCompiler>());
});
#else
CudaName(),
#endif
std::make_unique<StreamExecutorGpuCompiler>());
STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, {
PjRtRegisterCompiler(CudaName(),
std::make_unique<StreamExecutorGpuCompiler>());
});
#endif
} // namespace xla
8 changes: 4 additions & 4 deletions third_party/xla/xla/service/cpu/runtime/conv_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ void EigenConv2DImpl(
Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation,
Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation,
Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count,
std::optional<std::function<void()>> done_callback = std::nullopt) {
std::optional<std::function<void()>> done_callback) {
const Eigen::TensorMap<Eigen::Tensor<const ScalarType, 4, Eigen::RowMajor>,
Eigen::Aligned>
input(lhs, input_batch, input_x, input_y, input_channels);
Expand Down Expand Up @@ -129,7 +129,7 @@ void EigenConv3DImpl(
Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation,
Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation,
Eigen::Index feature_group_count,
std::optional<std::function<void()>> done_callback = std::nullopt) {
std::optional<std::function<void()>> done_callback) {
using ConstTType =
Eigen::TensorMap<Eigen::Tensor<const ScalarType, 5, Eigen::RowMajor>,
Eigen::Aligned>;
Expand Down Expand Up @@ -223,7 +223,7 @@ void EigenConv3DImpl(
Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, \
Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, \
Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count, \
std::optional<std::function<void()>> done_callback = std::nullopt)
std::optional<std::function<void()>> done_callback)

CONV2D_EXTERN_TEMPLATE(Eigen::DefaultDevice, Eigen::half);
CONV2D_EXTERN_TEMPLATE(Eigen::DefaultDevice, float);
Expand All @@ -249,7 +249,7 @@ CONV2D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, float);
Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation, \
Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation, \
Eigen::Index feature_group_count, \
std::optional<std::function<void()>> done_callback = std::nullopt)
std::optional<std::function<void()>> done_callback)

CONV3D_EXTERN_TEMPLATE(Eigen::DefaultDevice, Eigen::half);
CONV3D_EXTERN_TEMPLATE(Eigen::DefaultDevice, float);
Expand Down
10 changes: 7 additions & 3 deletions third_party/xla/xla/service/cpu/runtime/thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ std::string_view Thunk::KindToString(Kind kind) {
return "while";
}
}
Thunk::Thunk(Kind kind, Info info)
: kind_(kind),
info_(std::move(info)),
ok_event_(OkExecuteEventSingleton()) {}

absl::StatusOr<Thunk::CollectiveExecuteParams>
Thunk::CollectiveExecuteParams::Create(
Expand Down Expand Up @@ -150,13 +154,13 @@ Thunk::CustomCallExecuteParams::CustomCallExecuteParams(
allocator(allocator),
ffi_execution_context(ffi_execution_context) {}

const tsl::AsyncValueOwningRef<Thunk::ExecuteEvent>* Thunk::OkEvent() {
static tsl::AsyncValueOwningRef<ExecuteEvent>* owner = [] {
tsl::AsyncValueRef<Thunk::ExecuteEvent> Thunk::OkExecuteEventSingleton() {
static tsl::AsyncValueOwningRef<ExecuteEvent>* singleton = [] {
auto* storage = new tsl::internal::AsyncValueStorage<ExecuteEvent>();
return new tsl::AsyncValueOwningRef<ExecuteEvent>(
tsl::MakeAvailableAsyncValueRef<ExecuteEvent>(*storage));
}();
return owner;
return singleton->AsRef();
}

Thunk::ExecuteState::ExecuteState(int64_t num_tasks)
Expand Down
26 changes: 14 additions & 12 deletions third_party/xla/xla/service/cpu/runtime/thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class Thunk {
using Task = std::function<void()>;
using TaskRunner = absl::AnyInvocable<void(Task)>;

Thunk(Kind kind, Info info) : kind_(kind), info_(std::move(info)) {}
Thunk(Kind kind, Info info);

Thunk(const Thunk&) = delete;
Thunk& operator=(const Thunk&) = delete;
Expand Down Expand Up @@ -286,18 +286,20 @@ class Thunk {
// An execute event that becomes ready when all tasks are completed.
using ExecuteEvent = tsl::Chain;

// Returns non-reference-counted async value ref for thunks executed in the
// caller thread to avoid reference counting overhead.
static tsl::AsyncValueRef<ExecuteEvent> OkExecuteEvent() {
return OkEvent()->AsRef();
}
// Returns non-reference-counted async value ref in available state.
// Returned async value is a per-process singleton stored in a storage with a
// static duration, and can be safely compared using pointer equality.
static tsl::AsyncValueRef<ExecuteEvent> OkExecuteEventSingleton();

// Returns `OkExecuteEventSingleton()` cached by this thunk instance.
tsl::AsyncValueRef<ExecuteEvent> OkExecuteEvent() const { return ok_event_; }

static bool IsOkExecuteEvent(tsl::AsyncValuePtr<ExecuteEvent> event) {
return event == OkEvent()->AsPtr();
bool IsOkExecuteEvent(const tsl::AsyncValueRef<ExecuteEvent>& event) const {
return event == ok_event_;
}

static bool IsOkExecuteEvent(const tsl::AsyncValueRef<ExecuteEvent>& event) {
return IsOkExecuteEvent(event.AsPtr());
bool IsOkExecuteEvent(tsl::AsyncValuePtr<ExecuteEvent> event) const {
return event == ok_event_.AsPtr();
}

// Thunk execution must be asynchronous and never block the caller thread,
Expand Down Expand Up @@ -339,10 +341,10 @@ class Thunk {
}

private:
static const tsl::AsyncValueOwningRef<Thunk::ExecuteEvent>* OkEvent();

Kind kind_;
Info info_;

tsl::AsyncValueRef<ExecuteEvent> ok_event_;
};

std::ostream& operator<<(std::ostream& os, Thunk::Kind kind);
Expand Down
41 changes: 21 additions & 20 deletions third_party/xla/xla/service/cpu/runtime/thunk_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence,
const ThunkExecutor::Options& options)
: thunk_sequence_(std::move(thunk_sequence)),
options_(options),
num_thunks_(thunk_sequence_.size()),
nodes_defs_(std::move(nodes_defs)),
is_sequential_(true) {
for (NodeId i = 0; i < nodes_defs_.size(); ++i) {
Expand Down Expand Up @@ -143,10 +144,10 @@ ThunkExecutor::ExecuteState::ExecuteState(ThunkExecutor* executor,
tsl::AsyncValueRef<ThunkExecutor::ExecuteEvent> ThunkExecutor::Execute(
const Thunk::ExecuteParams& params) {
// Short-circuit execution of trivial thunk sequences.
if (ABSL_PREDICT_FALSE(thunk_sequence_.empty())) {
return Thunk::OkExecuteEvent();
if (ABSL_PREDICT_FALSE(num_thunks_ == 0)) {
return Thunk::OkExecuteEventSingleton();
}
if (ABSL_PREDICT_FALSE(thunk_sequence_.size() == 1)) {
if (ABSL_PREDICT_FALSE(num_thunks_ == 1)) {
return thunk_sequence_[0]->Execute(params);
}

Expand Down Expand Up @@ -176,24 +177,24 @@ tsl::AsyncValueRef<ThunkExecutor::ExecuteEvent> ThunkExecutor::Execute(

tsl::AsyncValueRef<ThunkExecutor::ExecuteEvent>
ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) {
for (int64_t i = 0; i < thunk_sequence_.size(); ++i) {
Thunk& thunk = *thunk_sequence_[i];
for (auto it = thunk_sequence_.begin(); it != thunk_sequence_.end(); ++it) {
Thunk& thunk = **it;
auto execute_event = thunk.Execute(params);

// Fast path for thunks that executed inline and returned OkExecuteEvent.
if (ABSL_PREDICT_TRUE(Thunk::IsOkExecuteEvent(execute_event))) {
if (ABSL_PREDICT_TRUE(thunk.IsOkExecuteEvent(execute_event))) {
continue;
}

// If thunk execution is not completed yet, attach a continuation to
// resume sequential execution starting from the next thunk.
if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) {
auto event = tsl::MakeConstructedAsyncValueRef<ExecuteEvent>();
execute_event.AndThen([this, &params, i, event](absl::Status status) {
execute_event.AndThen([this, &params, it, event](absl::Status status) {
if (ABSL_PREDICT_FALSE(!status.ok())) {
event.SetError(std::move(status));
} else {
ResumeExecuteSequential(i + 1, params, std::move(event));
ResumeExecuteSequential(it + 1, params, std::move(event));
}
});
return event;
Expand All @@ -207,30 +208,30 @@ ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) {

// If we got to the end of the sequence it means that all thunks have
// succeeded.
return Thunk::OkExecuteEvent();
return Thunk::OkExecuteEventSingleton();
}

void ThunkExecutor::ResumeExecuteSequential(
int64_t index, const Thunk::ExecuteParams& params,
ThunkIterator it, const Thunk::ExecuteParams& params,
tsl::AsyncValueRef<ExecuteEvent> event) {
for (int64_t i = index; i < thunk_sequence_.size(); ++i) {
Thunk& thunk = *thunk_sequence_[i];
for (; it != thunk_sequence_.end(); ++it) {
Thunk& thunk = **it;
auto execute_event = thunk.Execute(params);

// Fast path for thunks that executed inline and returned OkExecuteEvent.
if (ABSL_PREDICT_TRUE(Thunk::IsOkExecuteEvent(execute_event))) {
if (ABSL_PREDICT_TRUE(thunk.IsOkExecuteEvent(execute_event))) {
continue;
}

// If thunk execution is not completed yet, attach a continuation to
// resume sequential execution starting from the next thunk.
if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) {
execute_event.AndThen(
[this, &params, i, event = std::move(event)](absl::Status status) {
[this, &params, it, event = std::move(event)](absl::Status status) {
if (ABSL_PREDICT_FALSE(!status.ok())) {
event.SetError(std::move(status));
} else {
ResumeExecuteSequential(i + 1, params, std::move(event));
ResumeExecuteSequential(it + 1, params, std::move(event));
}
});
return;
Expand Down Expand Up @@ -281,7 +282,7 @@ void ThunkExecutor::Execute(ExecuteState* state,
Thunk& thunk = *state->executor->thunk_sequence_[id];
tsl::AsyncValueRef<ExecuteEvent> execute_event =
ABSL_PREDICT_FALSE(state->abort.load(std::memory_order_relaxed))
? Thunk::OkExecuteEvent()
? Thunk::OkExecuteEventSingleton()
: thunk.Execute(params);

if (ABSL_PREDICT_TRUE(execute_event.IsAvailable())) {
Expand Down Expand Up @@ -471,19 +472,19 @@ int64_t ThunkExecutor::TransitiveReduction() {

std::string ThunkExecutor::ToString() const {
std::string str = absl::StrFormat(
"ThunkExecutor: #thunks=%d #source_nodes=%d #sink_nodes=%d",
thunk_sequence_.size(), source_.size(), sink_.size());
"ThunkExecutor: #thunks=%d #source_nodes=%d #sink_nodes=%d", num_thunks_,
source_.size(), sink_.size());

// Collect names of `in_edges`.
std::vector<std::vector<std::string>> in_edges(thunk_sequence_.size());
std::vector<std::vector<std::string>> in_edges(num_thunks_);
for (const auto& node_def : nodes_defs_) {
for (NodeId in_edge : node_def.in_edges) {
in_edges[node_def.id].push_back(thunk_sequence_[in_edge]->info().op_name);
}
}

// Print thunks with a list of their dependencies.
for (NodeId i = 0; i < thunk_sequence_.size(); ++i) {
for (NodeId i = 0; i < num_thunks_; ++i) {
const Thunk& thunk = *thunk_sequence_[i];
bool is_source = absl::c_find(source_, i) != source_.end();
bool is_sink = absl::c_find(sink_, i) != sink_.end();
Expand Down
Loading

0 comments on commit fba5b1f

Please sign in to comment.