Skip to content

Commit

Permalink
[xla:cpu] Add Lifo ready queue for completeness
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 719486268
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Jan 27, 2025
1 parent c96f21e commit b55bcbe
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 37 deletions.
88 changes: 65 additions & 23 deletions xla/backends/cpu/runtime/thunk_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.

#include <sys/types.h>

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstdint>
Expand Down Expand Up @@ -220,12 +221,20 @@ tsl::AsyncValueRef<ThunkExecutor::ExecuteEvent> ThunkExecutor::Execute(
// This also works for thunks with nested thunk executors (i.e., WhileThunk),
// as launching nested thunk sequence must not reduce the available
// concurrency for the other thunks executing in parallel.
if (options_.use_priority_ready_queue) {
Execute(state.get(), params, PriorityReadyQueue(nodes_defs_, source_),
/*lock=*/nullptr);
} else {
Execute(state.get(), params, FifoReadyQueue(source_),
/*lock=*/nullptr);
auto execute = [&](auto ready_queue) {
Execute(state.get(), params, std::move(ready_queue), /*lock=*/nullptr);
};

switch (options_.ready_queue_type) {
case Options::ReadyQueueType::kFifo:
execute(FifoReadyQueue(source_));
break;
case Options::ReadyQueueType::kLifo:
execute(LifoReadyQueue(source_));
break;
case Options::ReadyQueueType::kPriority:
execute(PriorityReadyQueue(nodes_defs_, source_));
break;
}

// If execution already completed (all kernels executed in the caller thread),
Expand Down Expand Up @@ -294,7 +303,7 @@ ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) {

if (ABSL_PREDICT_FALSE(!status.ok())) {
event.SetError(std::move(status));
} else if (!runner || runner->current_worker_id()) {
} else if (ABSL_PREDICT_TRUE(!runner || runner->current_worker_id())) {
// Resume execution in the current thread if we are already running
// on a thread managed by the task runner.
ResumeExecuteSequential(it + 1, params, std::move(event));
Expand Down Expand Up @@ -334,23 +343,23 @@ void ThunkExecutor::ResumeExecuteSequential(
// If thunk execution is not completed yet, attach a continuation to
// resume sequential execution starting from the next thunk.
if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) {
execute_event.AndThen(
[this, &params, it, event = std::move(event)](absl::Status status) {
Thunk::TaskRunner* runner = params.task_runner;
execute_event.AndThen([this, &params, it,
event = std::move(event)](absl::Status status) {
Thunk::TaskRunner* runner = params.task_runner;

if (ABSL_PREDICT_FALSE(!status.ok())) {
event.SetError(std::move(status));
} else if (!runner || runner->current_worker_id()) {
// Resume execution in the current thread if we are already
// running on a thread managed by the task runner.
ResumeExecuteSequential(it + 1, params, std::move(event));
} else {
// Resume execution in the task runner to avoid thread "leaks".
(*runner)([this, &params, it, event = std::move(event)] {
ResumeExecuteSequential(it + 1, params, std::move(event));
});
}
if (ABSL_PREDICT_FALSE(!status.ok())) {
event.SetError(std::move(status));
} else if (ABSL_PREDICT_TRUE(!runner || runner->current_worker_id())) {
// Resume execution in the current thread if we are already
// running on a thread managed by the task runner.
ResumeExecuteSequential(it + 1, params, std::move(event));
} else {
// Resume execution in the task runner to avoid thread "leaks".
(*runner)([this, &params, it, event = std::move(event)] {
ResumeExecuteSequential(it + 1, params, std::move(event));
});
}
});
return;
}

Expand Down Expand Up @@ -443,7 +452,7 @@ void ThunkExecutor::Execute(ExecuteState* state,
}

Thunk::TaskRunner* runner = state->runner;
if (!runner || runner->current_worker_id()) {
if (ABSL_PREDICT_TRUE(!runner || runner->current_worker_id())) {
// Resume execution in the current thread if we are already
// running on a thread managed by the task runner.
state->executor->Execute(state, params, std::move(ready_queue),
Expand Down Expand Up @@ -740,6 +749,39 @@ ThunkExecutor::FifoReadyQueue::CreateEmptyReadyQueue() const {
return FifoReadyQueue(absl::Span<const NodeId>());
}

ThunkExecutor::LifoReadyQueue::LifoReadyQueue(
    absl::Span<const NodeId> ready_nodes) {
  // Copy the initial ready nodes; Pop() returns them in reverse push order.
  queue_.assign(ready_nodes.begin(), ready_nodes.end());
}

void ThunkExecutor::LifoReadyQueue::Push(NodeId id) { queue_.push_back(id); }

// Removes and returns the most recently pushed node (LIFO order).
ThunkExecutor::NodeId ThunkExecutor::LifoReadyQueue::Pop() {
  DCHECK(!Empty()) << "Queue must not be empty";
  const NodeId popped = queue_.back();
  queue_.pop_back();
  return popped;
}

// Splits off the oldest `Size() / 2 + 1` nodes into a new queue, keeping the
// most recently pushed nodes in this queue so the caller retains LIFO order
// over the freshest work.
ThunkExecutor::LifoReadyQueue ThunkExecutor::LifoReadyQueue::PopHalf() {
  DCHECK(!Empty()) << "Queue must not be empty";
  const size_t num_popped = queue_.size() / 2 + 1;
  LifoReadyQueue popped(absl::MakeConstSpan(queue_.data(), num_popped));

  // Drop the popped prefix; remaining nodes shift to the front of the queue.
  queue_.erase(queue_.begin(), queue_.begin() + num_popped);
  return popped;
}

size_t ThunkExecutor::LifoReadyQueue::Size() const { return queue_.size(); }

bool ThunkExecutor::LifoReadyQueue::Empty() const { return queue_.empty(); }

// Creates a new queue of the same type with no ready nodes.
ThunkExecutor::LifoReadyQueue
ThunkExecutor::LifoReadyQueue::CreateEmptyReadyQueue() const {
  return LifoReadyQueue({});
}

ThunkExecutor::PriorityReadyQueue::PriorityReadyQueue(
absl::Span<const NodeDef> nodes_defs, absl::Span<const NodeId> ready_nodes)
: nodes_defs_(nodes_defs),
Expand Down
26 changes: 23 additions & 3 deletions xla/backends/cpu/runtime/thunk_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ namespace internal {
// Clang does not allow defining a nested struct with member initializer, as
// a workaround we define a struct in internal namespace and create an alias.
struct ThunkExecutorOptions {
enum class ReadyQueueType { kFifo, kLifo, kPriority };

// If all thunks in a sequence use buffers of size less than or equal to the
// given threshold, we mark execution as sequential, as concurrency overheads
// will likely dominate the overall execution time.
Expand All @@ -54,9 +56,8 @@ struct ThunkExecutorOptions {
// the overall execution time.
size_t execute_sequential_num_thunks_threshold = 8;

// Use priority ready queue to execute nodes according to their priority. By
// default we use FIFO ready queue.
bool use_priority_ready_queue = false;
// The type of a queue for ready thunks.
ReadyQueueType ready_queue_type = ReadyQueueType::kFifo;
};
} // namespace internal

Expand Down Expand Up @@ -146,6 +147,25 @@ class ThunkExecutor {
size_t head_ = 0;
};

// A ready queue that executes nodes in LIFO order.
class LifoReadyQueue {
 public:
  explicit LifoReadyQueue(absl::Span<const NodeId> ready_nodes);

  // Adds a node to the queue; it becomes the next node returned by Pop().
  void Push(NodeId id);

  // Removes and returns the most recently pushed node. The queue must not be
  // empty.
  NodeId Pop();
  // Moves the oldest Size() / 2 + 1 nodes into a new queue; this queue keeps
  // the most recently pushed nodes. The queue must not be empty.
  LifoReadyQueue PopHalf();

  size_t Size() const;
  bool Empty() const;

  // Returns a new empty queue of the same type.
  LifoReadyQueue CreateEmptyReadyQueue() const;

 private:
  absl::InlinedVector<NodeId, 8> queue_;
};

// A ready queue that executes nodes sorted by NodeDef priority.
class PriorityReadyQueue {
public:
Expand Down
80 changes: 69 additions & 11 deletions xla/backends/cpu/runtime/thunk_executor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,14 @@ TEST(ThunkExecutorTest, FifoReadyQueueTest) {
queue.Push(2);
queue.Push(3);

EXPECT_EQ(queue.Size(), 3);
ASSERT_EQ(queue.Size(), 3);

EXPECT_EQ(queue.Pop(), 1);
EXPECT_EQ(queue.Pop(), 2);
EXPECT_EQ(queue.Pop(), 3);

EXPECT_TRUE(queue.Empty());
EXPECT_EQ(queue.Size(), 0);
ASSERT_EQ(queue.Size(), 0);

// Prepare queue for PopHalf test case.
queue.Push(1);
Expand All @@ -267,16 +267,16 @@ TEST(ThunkExecutorTest, FifoReadyQueueTest) {

// Pop half of the queue.
ThunkExecutor::FifoReadyQueue half0 = queue.PopHalf();
EXPECT_EQ(half0.Size(), 2);
ASSERT_EQ(half0.Size(), 2);
EXPECT_EQ(half0.Pop(), 2);
EXPECT_EQ(half0.Pop(), 3);

// Check that the rest is still in the queue.
EXPECT_EQ(queue.Size(), 1);
ASSERT_EQ(queue.Size(), 1);

// Pop the rest of the queue.
ThunkExecutor::FifoReadyQueue half1 = queue.PopHalf();
EXPECT_EQ(half1.Size(), 1);
ASSERT_EQ(half1.Size(), 1);

// Check that all nodes were returned from PopHalf.
EXPECT_EQ(queue.Size(), 0);
Expand All @@ -292,11 +292,69 @@ TEST(ThunkExecutorTest, FifoReadyQueueTest) {

// Check that PopHalf returns 2 last nodes.
ThunkExecutor::FifoReadyQueue half2 = queue.PopHalf();
EXPECT_EQ(half2.Size(), 2);
ASSERT_EQ(half2.Size(), 2);
EXPECT_EQ(half2.Pop(), 4);
EXPECT_EQ(half2.Pop(), 5);
}

// Verifies LifoReadyQueue pops in reverse push order and that PopHalf moves
// out the oldest Size() / 2 + 1 nodes while keeping the freshest ones.
TEST(ThunkExecutorTest, LifoReadyQueueTest) {
  ThunkExecutor::LifoReadyQueue queue({});

  // Check basic queue properties.
  EXPECT_TRUE(queue.Empty());
  EXPECT_EQ(queue.Size(), 0);

  queue.Push(1);
  queue.Push(2);
  queue.Push(3);

  ASSERT_EQ(queue.Size(), 3);

  // Nodes must come back in reverse (LIFO) push order.
  EXPECT_EQ(queue.Pop(), 3);
  EXPECT_EQ(queue.Pop(), 2);
  EXPECT_EQ(queue.Pop(), 1);

  EXPECT_TRUE(queue.Empty());
  // ASSERT so the PopHalf cases below start from a known-empty queue.
  ASSERT_EQ(queue.Size(), 0);

  // Prepare queue for PopHalf test case.
  queue.Push(1);
  queue.Push(2);
  queue.Push(3);

  // Pop half of the queue: the two oldest nodes (3 / 2 + 1 == 2).
  ThunkExecutor::LifoReadyQueue half0 = queue.PopHalf();
  ASSERT_EQ(half0.Size(), 2);
  EXPECT_EQ(half0.Pop(), 2);
  EXPECT_EQ(half0.Pop(), 1);

  // Check that the rest is still in the queue.
  ASSERT_EQ(queue.Size(), 1);

  // Pop the rest of the queue.
  ThunkExecutor::LifoReadyQueue half1 = queue.PopHalf();
  ASSERT_EQ(half1.Size(), 1);

  // Check that all nodes were returned from PopHalf.
  EXPECT_EQ(queue.Size(), 0);

  // Add 5 elements to test Pop followed by PopHalf.
  queue.Push(1);
  queue.Push(2);
  queue.Push(3);
  queue.Push(4);
  queue.Push(5);

  EXPECT_EQ(queue.Pop(), 5);

  // Check that PopHalf returns the first 3 nodes (4 / 2 + 1 == 3).
  ThunkExecutor::LifoReadyQueue half2 = queue.PopHalf();
  ASSERT_EQ(half2.Size(), 3);
  EXPECT_EQ(half2.Pop(), 3);
  EXPECT_EQ(half2.Pop(), 2);
  EXPECT_EQ(half2.Pop(), 1);
}

TEST(ThunkExecutorTest, PriorityReadyQueueTest) {
std::vector<ThunkExecutor::NodeDef> nodes_defs(16);
for (size_t i = 0; i < nodes_defs.size(); ++i) {
Expand Down Expand Up @@ -326,20 +384,20 @@ TEST(ThunkExecutorTest, PriorityReadyQueueTest) {

// Pop half of the queue.
ThunkExecutor::PriorityReadyQueue half0 = queue.PopHalf();
EXPECT_EQ(half0.Size(), 2);
ASSERT_EQ(half0.Size(), 2);
EXPECT_EQ(half0.Pop(), 2);
EXPECT_EQ(half0.Pop(), 1);

// Check that the rest is still in the queue.
EXPECT_EQ(queue.Size(), 1);
ASSERT_EQ(queue.Size(), 1);

// Pop the rest of the queue.
ThunkExecutor::PriorityReadyQueue half1 = queue.PopHalf();
EXPECT_EQ(half1.Size(), 1);
ASSERT_EQ(half1.Size(), 1);
EXPECT_EQ(half1.Pop(), 3);

// Check that all nodes were returned from PopHalf.
EXPECT_EQ(queue.Size(), 0);
ASSERT_EQ(queue.Size(), 0);

// Add 5 elements to test Pop followed by PopHalf.
queue.Push(4);
Expand All @@ -352,7 +410,7 @@ TEST(ThunkExecutorTest, PriorityReadyQueueTest) {

// Check that PopHalf returns 2 last nodes.
ThunkExecutor::PriorityReadyQueue half2 = queue.PopHalf();
EXPECT_EQ(half2.Size(), 2);
ASSERT_EQ(half2.Size(), 2);
EXPECT_EQ(half2.Pop(), 2);
EXPECT_EQ(half2.Pop(), 1);
}
Expand Down

0 comments on commit b55bcbe

Please sign in to comment.