Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[xla:cpu] Add Lifo ready queue for completeness #21905

Merged
merged 1 commit into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 65 additions & 23 deletions xla/backends/cpu/runtime/thunk_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.

#include <sys/types.h>

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstdint>
Expand Down Expand Up @@ -220,12 +221,20 @@ tsl::AsyncValueRef<ThunkExecutor::ExecuteEvent> ThunkExecutor::Execute(
// This also works for thunks with nested thunk executors (i.e., WhileThunk),
// as launching nested thunk sequence must not reduce the available
// concurrency for the other thunks executing in parallel.
if (options_.use_priority_ready_queue) {
Execute(state.get(), params, PriorityReadyQueue(nodes_defs_, source_),
/*lock=*/nullptr);
} else {
Execute(state.get(), params, FifoReadyQueue(source_),
/*lock=*/nullptr);
auto execute = [&](auto ready_queue) {
Execute(state.get(), params, std::move(ready_queue), /*lock=*/nullptr);
};

switch (options_.ready_queue_type) {
case Options::ReadyQueueType::kFifo:
execute(FifoReadyQueue(source_));
break;
case Options::ReadyQueueType::kLifo:
execute(LifoReadyQueue(source_));
break;
case Options::ReadyQueueType::kPriority:
execute(PriorityReadyQueue(nodes_defs_, source_));
break;
}

// If execution already completed (all kernels executed in the caller thread),
Expand Down Expand Up @@ -294,7 +303,7 @@ ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) {

if (ABSL_PREDICT_FALSE(!status.ok())) {
event.SetError(std::move(status));
} else if (!runner || runner->current_worker_id()) {
} else if (ABSL_PREDICT_TRUE(!runner || runner->current_worker_id())) {
// Resume execution in the current thread if we are already running
// on a thread managed by the task runner.
ResumeExecuteSequential(it + 1, params, std::move(event));
Expand Down Expand Up @@ -334,23 +343,23 @@ void ThunkExecutor::ResumeExecuteSequential(
// If thunk execution is not completed yet, attach a continuation to
// resume sequential execution starting from the next thunk.
if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) {
execute_event.AndThen(
[this, &params, it, event = std::move(event)](absl::Status status) {
Thunk::TaskRunner* runner = params.task_runner;
execute_event.AndThen([this, &params, it,
event = std::move(event)](absl::Status status) {
Thunk::TaskRunner* runner = params.task_runner;

if (ABSL_PREDICT_FALSE(!status.ok())) {
event.SetError(std::move(status));
} else if (!runner || runner->current_worker_id()) {
// Resume execution in the current thread if we are already
// running on a thread managed by the task runner.
ResumeExecuteSequential(it + 1, params, std::move(event));
} else {
// Resume execution in the task runner to avoid thread "leaks".
(*runner)([this, &params, it, event = std::move(event)] {
ResumeExecuteSequential(it + 1, params, std::move(event));
});
}
if (ABSL_PREDICT_FALSE(!status.ok())) {
event.SetError(std::move(status));
} else if (ABSL_PREDICT_TRUE(!runner || runner->current_worker_id())) {
// Resume execution in the current thread if we are already
// running on a thread managed by the task runner.
ResumeExecuteSequential(it + 1, params, std::move(event));
} else {
// Resume execution in the task runner to avoid thread "leaks".
(*runner)([this, &params, it, event = std::move(event)] {
ResumeExecuteSequential(it + 1, params, std::move(event));
});
}
});
return;
}

Expand Down Expand Up @@ -443,7 +452,7 @@ void ThunkExecutor::Execute(ExecuteState* state,
}

Thunk::TaskRunner* runner = state->runner;
if (!runner || runner->current_worker_id()) {
if (ABSL_PREDICT_TRUE(!runner || runner->current_worker_id())) {
// Resume execution in the current thread if we are already
// running on a thread managed by the task runner.
state->executor->Execute(state, params, std::move(ready_queue),
Expand Down Expand Up @@ -740,6 +749,39 @@ ThunkExecutor::FifoReadyQueue::CreateEmptyReadyQueue() const {
return FifoReadyQueue(absl::Span<const NodeId>());
}

// Seeds the LIFO queue with the given ready nodes; the last node in
// `ready_nodes` will be the first one returned by Pop().
ThunkExecutor::LifoReadyQueue::LifoReadyQueue(
    absl::Span<const NodeId> ready_nodes)
    : queue_(ready_nodes.begin(), ready_nodes.end()) {}

void ThunkExecutor::LifoReadyQueue::Push(NodeId id) { queue_.push_back(id); }

// Removes and returns the most recently pushed node (LIFO order).
// Precondition: the queue is not empty (DCHECK-enforced in debug builds).
ThunkExecutor::NodeId ThunkExecutor::LifoReadyQueue::Pop() {
  DCHECK(!Empty()) << "Queue must not be empty";
  NodeId id = queue_.back();
  queue_.pop_back();
  return id;
}

// Splits off the oldest `Size() / 2 + 1` nodes into a new queue and returns
// it, keeping the most recently pushed nodes in this queue. For a queue of
// size 1 this donates the single remaining node.
// Precondition: the queue is not empty (DCHECK-enforced in debug builds).
ThunkExecutor::LifoReadyQueue ThunkExecutor::LifoReadyQueue::PopHalf() {
  DCHECK(!Empty()) << "Queue must not be empty";
  const size_t num_popped = Size() / 2 + 1;

  // Donate the front (oldest) portion of the queue.
  LifoReadyQueue popped(absl::MakeConstSpan(queue_.data(), num_popped));

  // Drop the donated prefix, shifting the retained nodes to the front.
  queue_.erase(queue_.begin(), queue_.begin() + num_popped);
  return popped;
}

size_t ThunkExecutor::LifoReadyQueue::Size() const { return queue_.size(); }

bool ThunkExecutor::LifoReadyQueue::Empty() const { return queue_.empty(); }

// Returns a new, empty queue of the same type as this one.
ThunkExecutor::LifoReadyQueue
ThunkExecutor::LifoReadyQueue::CreateEmptyReadyQueue() const {
  return LifoReadyQueue(absl::Span<const NodeId>());
}

ThunkExecutor::PriorityReadyQueue::PriorityReadyQueue(
absl::Span<const NodeDef> nodes_defs, absl::Span<const NodeId> ready_nodes)
: nodes_defs_(nodes_defs),
Expand Down
26 changes: 23 additions & 3 deletions xla/backends/cpu/runtime/thunk_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ namespace internal {
// Clang does not allow defining a nested struct with member initializer, as
// a workaround we define a struct in internal namespace and create an alias.
struct ThunkExecutorOptions {
enum class ReadyQueueType { kFifo, kLifo, kPriority };

// If all thunks in a sequence use buffers of size less than or equal to the
// given threshold, we mark execution as sequential, as concurrency overheads
// will likely dominate the overall execution time.
Expand All @@ -54,9 +56,8 @@ struct ThunkExecutorOptions {
// the overall execution time.
size_t execute_sequential_num_thunks_threshold = 8;

// Use priority ready queue to execute nodes according to their priority. By
// default we use FIFO ready queue.
bool use_priority_ready_queue = false;
// The type of a queue for ready thunks.
ReadyQueueType ready_queue_type = ReadyQueueType::kFifo;
};
} // namespace internal

Expand Down Expand Up @@ -146,6 +147,25 @@ class ThunkExecutor {
size_t head_ = 0;
};

// A ready queue that executes nodes in LIFO order.
class LifoReadyQueue {
 public:
explicit LifoReadyQueue(absl::Span<const NodeId> ready_nodes);

// Appends `id`; it becomes the next node returned by Pop().
void Push(NodeId id);

// Pop() removes and returns the most recently pushed node. PopHalf()
// splits off roughly the oldest half of the nodes into a new queue.
NodeId Pop();
LifoReadyQueue PopHalf();

size_t Size() const;
bool Empty() const;

// Returns a new, empty queue of the same type.
LifoReadyQueue CreateEmptyReadyQueue() const;

 private:
// Inline capacity of 8 avoids heap allocation for small ready queues.
absl::InlinedVector<NodeId, 8> queue_;
};

// A ready queue that executes nodes sorted by NodeDef priority.
class PriorityReadyQueue {
public:
Expand Down
80 changes: 69 additions & 11 deletions xla/backends/cpu/runtime/thunk_executor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,14 @@ TEST(ThunkExecutorTest, FifoReadyQueueTest) {
queue.Push(2);
queue.Push(3);

EXPECT_EQ(queue.Size(), 3);
ASSERT_EQ(queue.Size(), 3);

EXPECT_EQ(queue.Pop(), 1);
EXPECT_EQ(queue.Pop(), 2);
EXPECT_EQ(queue.Pop(), 3);

EXPECT_TRUE(queue.Empty());
EXPECT_EQ(queue.Size(), 0);
ASSERT_EQ(queue.Size(), 0);

// Prepare queue for PopHalf test case.
queue.Push(1);
Expand All @@ -267,16 +267,16 @@ TEST(ThunkExecutorTest, FifoReadyQueueTest) {

// Pop half of the queue.
ThunkExecutor::FifoReadyQueue half0 = queue.PopHalf();
EXPECT_EQ(half0.Size(), 2);
ASSERT_EQ(half0.Size(), 2);
EXPECT_EQ(half0.Pop(), 2);
EXPECT_EQ(half0.Pop(), 3);

// Check that the rest is still in the queue.
EXPECT_EQ(queue.Size(), 1);
ASSERT_EQ(queue.Size(), 1);

// Pop the rest of the queue.
ThunkExecutor::FifoReadyQueue half1 = queue.PopHalf();
EXPECT_EQ(half1.Size(), 1);
ASSERT_EQ(half1.Size(), 1);

// Check that all nodes were returned from PopHalf.
EXPECT_EQ(queue.Size(), 0);
Expand All @@ -292,11 +292,69 @@ TEST(ThunkExecutorTest, FifoReadyQueueTest) {

// Check that PopHalf returns 2 last nodes.
ThunkExecutor::FifoReadyQueue half2 = queue.PopHalf();
EXPECT_EQ(half2.Size(), 2);
ASSERT_EQ(half2.Size(), 2);
EXPECT_EQ(half2.Pop(), 4);
EXPECT_EQ(half2.Pop(), 5);
}

// Exercises LifoReadyQueue: LIFO Push/Pop ordering, Size/Empty bookkeeping,
// and PopHalf donating the oldest `size / 2 + 1` nodes.
TEST(ThunkExecutorTest, LifoReadyQueueTest) {
ThunkExecutor::LifoReadyQueue queue({});

// Check basic queue properties.
EXPECT_TRUE(queue.Empty());
EXPECT_EQ(queue.Size(), 0);

queue.Push(1);
queue.Push(2);
queue.Push(3);

ASSERT_EQ(queue.Size(), 3);

// LIFO order: the most recently pushed node comes out first.
EXPECT_EQ(queue.Pop(), 3);
EXPECT_EQ(queue.Pop(), 2);
EXPECT_EQ(queue.Pop(), 1);

EXPECT_TRUE(queue.Empty());
EXPECT_EQ(queue.Size(), 0);

// Prepare queue for PopHalf test case.
queue.Push(1);
queue.Push(2);
queue.Push(3);

// Pop half of the queue: the 2 oldest nodes (size / 2 + 1 == 2).
ThunkExecutor::LifoReadyQueue half0 = queue.PopHalf();
ASSERT_EQ(half0.Size(), 2);
EXPECT_EQ(half0.Pop(), 2);
EXPECT_EQ(half0.Pop(), 1);

// Check that the rest is still in the queue.
ASSERT_EQ(queue.Size(), 1);

// Pop the rest of the queue.
ThunkExecutor::LifoReadyQueue half1 = queue.PopHalf();
ASSERT_EQ(half1.Size(), 1);

// Check that all nodes were returned from PopHalf.
EXPECT_EQ(queue.Size(), 0);

// Add 5 elements to test Pop followed by PopHalf.
queue.Push(1);
queue.Push(2);
queue.Push(3);
queue.Push(4);
queue.Push(5);

EXPECT_EQ(queue.Pop(), 5);

// Check that PopHalf returns the 3 oldest nodes (size / 2 + 1 == 3).
ThunkExecutor::LifoReadyQueue half2 = queue.PopHalf();
ASSERT_EQ(half2.Size(), 3);
EXPECT_EQ(half2.Pop(), 3);
EXPECT_EQ(half2.Pop(), 2);
EXPECT_EQ(half2.Pop(), 1);
}

TEST(ThunkExecutorTest, PriorityReadyQueueTest) {
std::vector<ThunkExecutor::NodeDef> nodes_defs(16);
for (size_t i = 0; i < nodes_defs.size(); ++i) {
Expand Down Expand Up @@ -326,20 +384,20 @@ TEST(ThunkExecutorTest, PriorityReadyQueueTest) {

// Pop half of the queue.
ThunkExecutor::PriorityReadyQueue half0 = queue.PopHalf();
EXPECT_EQ(half0.Size(), 2);
ASSERT_EQ(half0.Size(), 2);
EXPECT_EQ(half0.Pop(), 2);
EXPECT_EQ(half0.Pop(), 1);

// Check that the rest is still in the queue.
EXPECT_EQ(queue.Size(), 1);
ASSERT_EQ(queue.Size(), 1);

// Pop the rest of the queue.
ThunkExecutor::PriorityReadyQueue half1 = queue.PopHalf();
EXPECT_EQ(half1.Size(), 1);
ASSERT_EQ(half1.Size(), 1);
EXPECT_EQ(half1.Pop(), 3);

// Check that all nodes were returned from PopHalf.
EXPECT_EQ(queue.Size(), 0);
ASSERT_EQ(queue.Size(), 0);

// Add 5 elements to test Pop followed by PopHalf.
queue.Push(4);
Expand All @@ -352,7 +410,7 @@ TEST(ThunkExecutorTest, PriorityReadyQueueTest) {

// Check that PopHalf returns 2 last nodes.
ThunkExecutor::PriorityReadyQueue half2 = queue.PopHalf();
EXPECT_EQ(half2.Size(), 2);
ASSERT_EQ(half2.Size(), 2);
EXPECT_EQ(half2.Pop(), 2);
EXPECT_EQ(half2.Pop(), 1);
}
Expand Down
Loading