diff --git a/xla/backends/cpu/runtime/thunk_executor.cc b/xla/backends/cpu/runtime/thunk_executor.cc index 97625473b44200..3f02a089e5236a 100644 --- a/xla/backends/cpu/runtime/thunk_executor.cc +++ b/xla/backends/cpu/runtime/thunk_executor.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include #include #include #include @@ -220,12 +221,20 @@ tsl::AsyncValueRef ThunkExecutor::Execute( // This also works for thunks with nested thunk executors (i.e., WhileThunk), // as launching nested thunk sequence must not reduce the available // concurrency for the other thunks executing in parallel. - if (options_.use_priority_ready_queue) { - Execute(state.get(), params, PriorityReadyQueue(nodes_defs_, source_), - /*lock=*/nullptr); - } else { - Execute(state.get(), params, FifoReadyQueue(source_), - /*lock=*/nullptr); + auto execute = [&](auto ready_queue) { + Execute(state.get(), params, std::move(ready_queue), /*lock=*/nullptr); + }; + + switch (options_.ready_queue_type) { + case Options::ReadyQueueType::kFifo: + execute(FifoReadyQueue(source_)); + break; + case Options::ReadyQueueType::kLifo: + execute(LifoReadyQueue(source_)); + break; + case Options::ReadyQueueType::kPriority: + execute(PriorityReadyQueue(nodes_defs_, source_)); + break; } // If execution already completed (all kernels executed in the caller thread), @@ -294,7 +303,7 @@ ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) { if (ABSL_PREDICT_FALSE(!status.ok())) { event.SetError(std::move(status)); - } else if (!runner || runner->current_worker_id()) { + } else if (ABSL_PREDICT_TRUE(!runner || runner->current_worker_id())) { // Resume execution in the current thread if we are already running // on a thread managed by the task runner. 
ResumeExecuteSequential(it + 1, params, std::move(event)); @@ -334,23 +343,23 @@ void ThunkExecutor::ResumeExecuteSequential( // If thunk execution is not completed yet, attach a continuation to // resume sequential execution starting from the next thunk. if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) { - execute_event.AndThen( - [this, &params, it, event = std::move(event)](absl::Status status) { - Thunk::TaskRunner* runner = params.task_runner; + execute_event.AndThen([this, &params, it, + event = std::move(event)](absl::Status status) { + Thunk::TaskRunner* runner = params.task_runner; - if (ABSL_PREDICT_FALSE(!status.ok())) { - event.SetError(std::move(status)); - } else if (!runner || runner->current_worker_id()) { - // Resume execution in the current thread if we are already - // running on a thread managed by the task runner. - ResumeExecuteSequential(it + 1, params, std::move(event)); - } else { - // Resume execution in the task runner to avoid thread "leaks". - (*runner)([this, &params, it, event = std::move(event)] { - ResumeExecuteSequential(it + 1, params, std::move(event)); - }); - } + if (ABSL_PREDICT_FALSE(!status.ok())) { + event.SetError(std::move(status)); + } else if (ABSL_PREDICT_TRUE(!runner || runner->current_worker_id())) { + // Resume execution in the current thread if we are already + // running on a thread managed by the task runner. + ResumeExecuteSequential(it + 1, params, std::move(event)); + } else { + // Resume execution in the task runner to avoid thread "leaks". + (*runner)([this, &params, it, event = std::move(event)] { + ResumeExecuteSequential(it + 1, params, std::move(event)); }); + } + }); return; } @@ -443,7 +452,7 @@ void ThunkExecutor::Execute(ExecuteState* state, } Thunk::TaskRunner* runner = state->runner; - if (!runner || runner->current_worker_id()) { + if (ABSL_PREDICT_TRUE(!runner || runner->current_worker_id())) { // Resume execution in the current thread if we are already // running // on a thread managed by the task runner. 
state->executor->Execute(state, params, std::move(ready_queue), @@ -740,6 +749,39 @@ ThunkExecutor::FifoReadyQueue::CreateEmptyReadyQueue() const { return FifoReadyQueue(absl::Span<const NodeId>()); } +ThunkExecutor::LifoReadyQueue::LifoReadyQueue( + absl::Span<const NodeId> ready_nodes) + : queue_(ready_nodes.begin(), ready_nodes.end()) {} + +void ThunkExecutor::LifoReadyQueue::Push(NodeId id) { queue_.push_back(id); } + +ThunkExecutor::NodeId ThunkExecutor::LifoReadyQueue::Pop() { + DCHECK(!Empty()) << "Queue must not be empty"; + NodeId id = queue_.back(); + queue_.pop_back(); + return id; +} + +ThunkExecutor::LifoReadyQueue ThunkExecutor::LifoReadyQueue::PopHalf() { + DCHECK(!Empty()) << "Queue must not be empty"; + auto mid = Size() / 2 + 1; + LifoReadyQueue popped( + absl::MakeConstSpan(queue_.begin(), queue_.begin() + mid)); + + std::move(queue_.begin() + mid, queue_.end(), queue_.begin()); + queue_.resize(queue_.size() - mid); + return popped; +} + +size_t ThunkExecutor::LifoReadyQueue::Size() const { return queue_.size(); } + +bool ThunkExecutor::LifoReadyQueue::Empty() const { return queue_.empty(); } + +ThunkExecutor::LifoReadyQueue +ThunkExecutor::LifoReadyQueue::CreateEmptyReadyQueue() const { + return LifoReadyQueue(absl::Span<const NodeId>()); +} + ThunkExecutor::PriorityReadyQueue::PriorityReadyQueue( absl::Span<const NodeDef> nodes_defs, absl::Span<const NodeId> ready_nodes) : nodes_defs_(nodes_defs), diff --git a/xla/backends/cpu/runtime/thunk_executor.h b/xla/backends/cpu/runtime/thunk_executor.h index 88adf96b65bea5..a29b7b29453efa 100644 --- a/xla/backends/cpu/runtime/thunk_executor.h +++ b/xla/backends/cpu/runtime/thunk_executor.h @@ -44,6 +44,8 @@ namespace internal { // Clang does not allow defining a nested struct with member initializer, as // a workaround we define a struct in internal namespace and create an alias. 
struct ThunkExecutorOptions { + enum class ReadyQueueType { kFifo, kLifo, kPriority }; + // If all thunks in a sequence use buffers of size less than or equal to the // given threshold, we mark execution as sequential, as concurrency overheads // will likely dominate the overall execution time. @@ -54,9 +56,8 @@ struct ThunkExecutorOptions { // the overall execution time. size_t execute_sequential_num_thunks_threshold = 8; - // Use priority ready queue to execute nodes according to their priority. By - // default we use FIFO ready queue. - bool use_priority_ready_queue = false; + // The type of a queue for ready thunks. + ReadyQueueType ready_queue_type = ReadyQueueType::kFifo; }; } // namespace internal @@ -146,6 +147,25 @@ class ThunkExecutor { size_t head_ = 0; }; + // A ready queue that executes nodes in LIFO order. + class LifoReadyQueue { + public: + explicit LifoReadyQueue(absl::Span<const NodeId> ready_nodes); + + void Push(NodeId id); + + NodeId Pop(); + LifoReadyQueue PopHalf(); + + size_t Size() const; + bool Empty() const; + + LifoReadyQueue CreateEmptyReadyQueue() const; + + private: + absl::InlinedVector<NodeId, 8> queue_; + }; + // A ready queue that executes nodes sorted by NodeDef priority. class PriorityReadyQueue { public: diff --git a/xla/backends/cpu/runtime/thunk_executor_test.cc b/xla/backends/cpu/runtime/thunk_executor_test.cc index dd315236916dd1..fee3485030db2e 100644 --- a/xla/backends/cpu/runtime/thunk_executor_test.cc +++ b/xla/backends/cpu/runtime/thunk_executor_test.cc @@ -251,14 +251,14 @@ TEST(ThunkExecutorTest, FifoReadyQueueTest) { queue.Push(2); queue.Push(3); - EXPECT_EQ(queue.Size(), 3); + ASSERT_EQ(queue.Size(), 3); EXPECT_EQ(queue.Pop(), 1); EXPECT_EQ(queue.Pop(), 2); EXPECT_EQ(queue.Pop(), 3); EXPECT_TRUE(queue.Empty()); - EXPECT_EQ(queue.Size(), 0); + ASSERT_EQ(queue.Size(), 0); // Prepare queue for PopHalf test case. queue.Push(1); @@ -267,16 +267,16 @@ TEST(ThunkExecutorTest, FifoReadyQueueTest) { // Pop half of the queue. 
ThunkExecutor::FifoReadyQueue half0 = queue.PopHalf(); - EXPECT_EQ(half0.Size(), 2); + ASSERT_EQ(half0.Size(), 2); EXPECT_EQ(half0.Pop(), 2); EXPECT_EQ(half0.Pop(), 3); // Check that the rest is still in the queue. - EXPECT_EQ(queue.Size(), 1); + ASSERT_EQ(queue.Size(), 1); // Pop the rest of the queue. ThunkExecutor::FifoReadyQueue half1 = queue.PopHalf(); - EXPECT_EQ(half1.Size(), 1); + ASSERT_EQ(half1.Size(), 1); // Check that all nodes were returned from PopHalf. EXPECT_EQ(queue.Size(), 0); @@ -292,11 +292,69 @@ TEST(ThunkExecutorTest, FifoReadyQueueTest) { // Check that PopHalf returns 2 last nodes. ThunkExecutor::FifoReadyQueue half2 = queue.PopHalf(); - EXPECT_EQ(half2.Size(), 2); + ASSERT_EQ(half2.Size(), 2); EXPECT_EQ(half2.Pop(), 4); EXPECT_EQ(half2.Pop(), 5); } +TEST(ThunkExecutorTest, LifoReadyQueueTest) { + ThunkExecutor::LifoReadyQueue queue({}); + + // Check basic queue properties. + EXPECT_TRUE(queue.Empty()); + EXPECT_EQ(queue.Size(), 0); + + queue.Push(1); + queue.Push(2); + queue.Push(3); + + ASSERT_EQ(queue.Size(), 3); + + EXPECT_EQ(queue.Pop(), 3); + EXPECT_EQ(queue.Pop(), 2); + EXPECT_EQ(queue.Pop(), 1); + + EXPECT_TRUE(queue.Empty()); + EXPECT_EQ(queue.Size(), 0); + + // Prepare queue for PopHalf test case. + queue.Push(1); + queue.Push(2); + queue.Push(3); + + // Pop half of the queue. + ThunkExecutor::LifoReadyQueue half0 = queue.PopHalf(); + ASSERT_EQ(half0.Size(), 2); + EXPECT_EQ(half0.Pop(), 2); + EXPECT_EQ(half0.Pop(), 1); + + // Check that the rest is still in the queue. + ASSERT_EQ(queue.Size(), 1); + + // Pop the rest of the queue. + ThunkExecutor::LifoReadyQueue half1 = queue.PopHalf(); + ASSERT_EQ(half1.Size(), 1); + + // Check that all nodes were returned from PopHalf. + EXPECT_EQ(queue.Size(), 0); + + // Add 5 elements to test Pop followed by PopHalf. + queue.Push(1); + queue.Push(2); + queue.Push(3); + queue.Push(4); + queue.Push(5); + + EXPECT_EQ(queue.Pop(), 5); + + // Check that PopHalf returns first 3 nodes. 
+ ThunkExecutor::LifoReadyQueue half2 = queue.PopHalf(); + ASSERT_EQ(half2.Size(), 3); + EXPECT_EQ(half2.Pop(), 3); + EXPECT_EQ(half2.Pop(), 2); + EXPECT_EQ(half2.Pop(), 1); +} + TEST(ThunkExecutorTest, PriorityReadyQueueTest) { std::vector<ThunkExecutor::NodeDef> nodes_defs(16); for (size_t i = 0; i < nodes_defs.size(); ++i) { @@ -326,20 +384,20 @@ TEST(ThunkExecutorTest, PriorityReadyQueueTest) { // Pop half of the queue. ThunkExecutor::PriorityReadyQueue half0 = queue.PopHalf(); - EXPECT_EQ(half0.Size(), 2); + ASSERT_EQ(half0.Size(), 2); EXPECT_EQ(half0.Pop(), 2); EXPECT_EQ(half0.Pop(), 1); // Check that the rest is still in the queue. - EXPECT_EQ(queue.Size(), 1); + ASSERT_EQ(queue.Size(), 1); // Pop the rest of the queue. ThunkExecutor::PriorityReadyQueue half1 = queue.PopHalf(); - EXPECT_EQ(half1.Size(), 1); + ASSERT_EQ(half1.Size(), 1); EXPECT_EQ(half1.Pop(), 3); // Check that all nodes were returned from PopHalf. - EXPECT_EQ(queue.Size(), 0); + ASSERT_EQ(queue.Size(), 0); // Add 5 elements to test Pop followed by PopHalf. queue.Push(4); @@ -352,7 +410,7 @@ TEST(ThunkExecutorTest, PriorityReadyQueueTest) { // Check that PopHalf returns 2 last nodes. ThunkExecutor::PriorityReadyQueue half2 = queue.PopHalf(); - EXPECT_EQ(half2.Size(), 2); + ASSERT_EQ(half2.Size(), 2); EXPECT_EQ(half2.Pop(), 2); EXPECT_EQ(half2.Pop(), 1); }