From 083b26f55582a1567d35a2b31e9c128911e2b3ba Mon Sep 17 00:00:00 2001 From: sitaowang1998 Date: Sun, 8 Dec 2024 17:23:01 -0500 Subject: [PATCH] feat: Add scheduler server that listens and responds to schedule task requests (#33) --- src/spider/CMakeLists.txt | 3 + src/spider/scheduler/SchedulerMessage.hpp | 56 +++++++++ src/spider/scheduler/SchedulerServer.cpp | 133 ++++++++++++++++++++++ src/spider/scheduler/SchedulerServer.hpp | 57 ++++++++++ src/spider/storage/MysqlStorage.cpp | 11 ++ tests/CMakeLists.txt | 1 + tests/scheduler/test-SchedulerPolicy.cpp | 8 -- tests/scheduler/test-SchedulerServer.cpp | 94 +++++++++++++++ tests/storage/test-DataStorage.cpp | 6 +- tests/utils/CoreTaskUtils.cpp | 7 +- 10 files changed, 364 insertions(+), 12 deletions(-) create mode 100644 src/spider/scheduler/SchedulerMessage.hpp create mode 100644 src/spider/scheduler/SchedulerServer.cpp create mode 100644 src/spider/scheduler/SchedulerServer.hpp create mode 100644 tests/scheduler/test-SchedulerServer.cpp diff --git a/src/spider/CMakeLists.txt b/src/spider/CMakeLists.txt index b4a0576..d550ca1 100644 --- a/src/spider/CMakeLists.txt +++ b/src/spider/CMakeLists.txt @@ -91,6 +91,9 @@ set(SPIDER_SCHEDULER_SOURCES scheduler/SchedulerPolicy.hpp scheduler/FifoPolicy.cpp scheduler/FifoPolicy.hpp + scheduler/SchedulerMessage.hpp + scheduler/SchedulerServer.cpp + scheduler/SchedulerServer.hpp CACHE INTERNAL "spider scheduler source files" ) diff --git a/src/spider/scheduler/SchedulerMessage.hpp b/src/spider/scheduler/SchedulerMessage.hpp new file mode 100644 index 0000000..9d2aa77 --- /dev/null +++ b/src/spider/scheduler/SchedulerMessage.hpp @@ -0,0 +1,56 @@ +#ifndef SPIDER_SCHEDULER_SCHEDULERMESSAGE_HPP +#define SPIDER_SCHEDULER_SCHEDULERMESSAGE_HPP + +#include +#include +#include + +#include + +#include "../io/MsgPack.hpp" // IWYU pragma: keep +#include "../io/Serializer.hpp" // IWYU pragma: keep + +namespace spider::scheduler { + +class ScheduleTaskRequest { +public: + /** + * Default constructor for msgpack. Do __not__ use it directly. + */ + ScheduleTaskRequest() = default; + + ScheduleTaskRequest(boost::uuids::uuid const worker_id, std::string addr) + : m_worker_id{worker_id}, + m_worker_addr{std::move(addr)} {} + + [[nodiscard]] auto get_worker_id() const -> boost::uuids::uuid { return m_worker_id; } + + [[nodiscard]] auto get_worker_addr() const -> std::string const& { return m_worker_addr; } + + MSGPACK_DEFINE_ARRAY(m_worker_id, m_worker_addr); + +private: + boost::uuids::uuid m_worker_id; + std::string m_worker_addr; +}; + +class ScheduleTaskResponse { +public: + ScheduleTaskResponse() = default; + + explicit ScheduleTaskResponse(boost::uuids::uuid const task_id) : m_task_id{task_id} {} + + [[nodiscard]] auto has_task_id() const -> bool { return m_task_id.has_value(); } + + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + [[nodiscard]] auto get_task_id() const -> boost::uuids::uuid { return m_task_id.value(); } + + MSGPACK_DEFINE_ARRAY(m_task_id); + +private: + std::optional m_task_id = std::nullopt; +}; + +} // namespace spider::scheduler + +#endif // SPIDER_SCHEDULER_SCHEDULERMESSAGE_HPP diff --git a/src/spider/scheduler/SchedulerServer.cpp b/src/spider/scheduler/SchedulerServer.cpp new file mode 100644 index 0000000..110cfc4 --- /dev/null +++ b/src/spider/scheduler/SchedulerServer.cpp @@ -0,0 +1,133 @@ +#include "SchedulerServer.hpp" + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../io/BoostAsio.hpp" // IWYU pragma: keep +#include "../io/MsgPack.hpp" // IWYU pragma: keep +#include "../io/msgpack_message.hpp" +#include "../io/Serializer.hpp" // IWYU pragma: keep +#include "../storage/DataStorage.hpp" +#include "../storage/MetadataStorage.hpp" +#include "SchedulerMessage.hpp" +#include "SchedulerPolicy.hpp" + +namespace spider::scheduler { + +SchedulerServer::SchedulerServer( + unsigned short const port, + std::shared_ptr policy, + std::shared_ptr metadata_store, + std::shared_ptr data_store +) + : m_acceptor{m_context, {boost::asio::ip::tcp::v4(), port}}, + m_policy{std::move(policy)}, + m_metadata_store{std::move(metadata_store)}, + m_data_store{std::move(data_store)} { + // Ignore the returned future as we do not need its value + boost::asio::co_spawn(m_context, receive_message(), boost::asio::use_future); +} + +auto SchedulerServer::run() -> void { + m_context.run(); +} + +auto SchedulerServer::stop() -> void { + m_context.stop(); + std::lock_guard const lock{m_mutex}; + m_stop = true; +} + +auto SchedulerServer::receive_message() -> boost::asio::awaitable { + while (!should_stop()) { + // std::unique_ptr socket + // = std::make_unique(m_context); + boost::asio::ip::tcp::socket socket{m_context}; + auto const& [ec] = co_await m_acceptor.async_accept( + socket, + boost::asio::as_tuple(boost::asio::use_awaitable) + ); + if (ec) { + spdlog::error("Cannot accept connection {}: {}", ec.value(), ec.what()); + continue; + } + // Ignore the returned future as we do not need its value + boost::asio::co_spawn( + m_context, + process_message(std::move(socket)), + boost::asio::use_future + ); + } + co_return; +} + +namespace { +auto deserialize_message(msgpack::sbuffer const& buffer) -> std::optional { + try { + msgpack::object_handle const handle = msgpack::unpack(buffer.data(), buffer.size()); + msgpack::object const object = handle.get(); + return object.as(); + } catch (std::runtime_error& e) { + spdlog::error("Cannot unpack message to ScheduleTaskRequest: {}", e.what()); + return std::nullopt; + } +} +} // namespace + +auto SchedulerServer::process_message(boost::asio::ip::tcp::socket socket +) -> boost::asio::awaitable { + // NOLINTBEGIN(clang-analyzer-core.CallAndMessage) + std::optional const& optional_message_buffer + = co_await core::receive_message_async(socket); + // NOLINTEND(clang-analyzer-core.CallAndMessage) + + if (false == optional_message_buffer.has_value()) { + spdlog::error("Cannot receive message from worker"); + co_return; + } + msgpack::sbuffer const& message_buffer = optional_message_buffer.value(); + std::optional const& optional_request + = deserialize_message(message_buffer); + if (false == optional_request.has_value()) { + spdlog::error("Cannot parse message into schedule task request"); + co_return; + } + ScheduleTaskRequest const& request = optional_request.value(); + + std::optional const task_id = m_policy->schedule_next( + m_metadata_store, + m_data_store, + request.get_worker_id(), + request.get_worker_addr() + ); + ScheduleTaskResponse response{}; + if (task_id.has_value()) { + response = ScheduleTaskResponse{task_id.value()}; + } + msgpack::sbuffer response_buffer; + msgpack::pack(response_buffer, response); + + bool const success = co_await core::send_message_async(socket, response_buffer); + if (!success) { + spdlog::error( + "Cannot send message to worker {} at {}", + boost::uuids::to_string(request.get_worker_id()), + request.get_worker_addr() + ); + } + co_return; +} + +auto SchedulerServer::should_stop() -> bool { + std::lock_guard const lock{m_mutex}; + return m_stop; +} + +} // namespace spider::scheduler diff --git a/src/spider/scheduler/SchedulerServer.hpp b/src/spider/scheduler/SchedulerServer.hpp new file mode 100644 index 0000000..1457b6d --- /dev/null +++ b/src/spider/scheduler/SchedulerServer.hpp @@ -0,0 +1,57 @@ +#ifndef SPIDER_SCHEDULER_SCHEDULERSERVER_HPP +#define SPIDER_SCHEDULER_SCHEDULERSERVER_HPP + +#include +#include + +#include "../io/BoostAsio.hpp" // IWYU pragma: keep +#include "../storage/DataStorage.hpp" +#include "../storage/MetadataStorage.hpp" +#include "SchedulerPolicy.hpp" + +namespace spider::scheduler { + +class SchedulerServer { +public: + // Delete copy & move constructor and assignment operator + SchedulerServer(SchedulerServer const&) = delete; + auto operator=(SchedulerServer const&) -> SchedulerServer& = delete; + SchedulerServer(SchedulerServer&&) = delete; + auto operator=(SchedulerServer&&) noexcept -> SchedulerServer& = delete; + ~SchedulerServer() = default; + + SchedulerServer( + unsigned short port, + std::shared_ptr policy, + std::shared_ptr metadata_store, + std::shared_ptr data_store + ); + + /** + * Run the server loop. This function blocks until stop is called. + */ + auto run() -> void; + + auto stop() -> void; + +private: + auto receive_message() -> boost::asio::awaitable; + + auto process_message(boost::asio::ip::tcp::socket socket) -> boost::asio::awaitable; + + auto should_stop() -> bool; + + std::shared_ptr m_policy; + std::shared_ptr m_metadata_store; + std::shared_ptr m_data_store; + + boost::asio::io_context m_context; + boost::asio::ip::tcp::acceptor m_acceptor; + + std::mutex m_mutex; + bool m_stop = false; +}; + +} // namespace spider::scheduler + +#endif // SPIDER_SCHEDULER_SCHEDULERSERVER_HPP diff --git a/src/spider/storage/MysqlStorage.cpp b/src/spider/storage/MysqlStorage.cpp index 0f3e756..213f2f4 100644 --- a/src/spider/storage/MysqlStorage.cpp +++ b/src/spider/storage/MysqlStorage.cpp @@ -473,6 +473,17 @@ auto MySqlMetadataStorage::add_job( dep_statement->setBytes(2, &child_id_bytes); dep_statement->executeUpdate(); } + + // Mark head tasks as ready + for (boost::uuids::uuid const& task_id : task_graph.get_head_tasks()) { + std::unique_ptr statement( + m_conn->prepareStatement("UPDATE `tasks` SET `state` = 'ready' WHERE `id` = ?") + ); + sql::bytes task_id_bytes = uuid_get_bytes(task_id); + statement->setBytes(1, &task_id_bytes); + statement->executeUpdate(); + } + } catch (sql::SQLException& e) { m_conn->rollback(); if (e.getErrorCode() == ErDupKey || e.getErrorCode() == ErDupEntry) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9d81394..494d9ea 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -10,6 +10,7 @@ set(SPIDER_TEST_SOURCES worker/test-TaskExecutor.cpp io/test-MsgpackMessage.cpp scheduler/test-SchedulerPolicy.cpp + scheduler/test-SchedulerServer.cpp CACHE INTERNAL "spider test source files" ) diff --git a/tests/scheduler/test-SchedulerPolicy.cpp b/tests/scheduler/test-SchedulerPolicy.cpp index 6e2b823..eec1d84 100644 --- a/tests/scheduler/test-SchedulerPolicy.cpp +++ b/tests/scheduler/test-SchedulerPolicy.cpp @@ -44,16 +44,12 @@ TEMPLATE_LIST_TEST_CASE( graph_1.add_task(task_1); boost::uuids::uuid const job_id_1 = gen(); REQUIRE(metadata_store->add_job(job_id_1, client_id, graph_1).success()); - REQUIRE(metadata_store->set_task_state(task_1.get_id(), spider::core::TaskState::Ready) - .success()); std::this_thread::sleep_for(std::chrono::seconds(1)); spider::core::Task const task_2{"task_2"}; spider::core::TaskGraph graph_2; graph_2.add_task(task_2); boost::uuids::uuid const job_id_2 = gen(); REQUIRE(metadata_store->add_job(job_id_2, client_id, graph_2).success()); - REQUIRE(metadata_store->set_task_state(task_2.get_id(), spider::core::TaskState::Ready) - .success()); spider::scheduler::FifoPolicy policy; @@ -97,8 +93,6 @@ TEMPLATE_LIST_TEST_CASE( spider::core::TaskGraph graph; graph.add_task(task); REQUIRE(metadata_store->add_job(job_id, gen(), graph).success()); - REQUIRE(metadata_store->set_task_state(task.get_id(), spider::core::TaskState::Ready).success() - ); spider::scheduler::FifoPolicy policy; // Schedule with wrong address @@ -142,8 +136,6 @@ TEMPLATE_LIST_TEST_CASE( spider::core::TaskGraph graph; graph.add_task(task); REQUIRE(metadata_store->add_job(job_id, gen(), graph).success()); - REQUIRE(metadata_store->set_task_state(task.get_id(), spider::core::TaskState::Ready).success() - ); spider::scheduler::FifoPolicy policy; // Schedule with wrong address diff --git a/tests/scheduler/test-SchedulerServer.cpp b/tests/scheduler/test-SchedulerServer.cpp new file mode 100644 index 0000000..dc88a0d --- /dev/null +++ b/tests/scheduler/test-SchedulerServer.cpp @@ -0,0 +1,94 @@ +// NOLINTBEGIN(cert-err58-cpp,cppcoreguidelines-avoid-do-while,readability-function-cognitive-complexity,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,clang-analyzer-optin.core.EnumCastOutOfRange) +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../src/spider/core/Task.hpp" +#include "../../src/spider/core/TaskGraph.hpp" +#include "../../src/spider/io/BoostAsio.hpp" // IWYU pragma: keep +#include "../../src/spider/io/MsgPack.hpp" // IWYU pragma: keep +#include "../../src/spider/io/msgpack_message.hpp" +#include "../../src/spider/scheduler/FifoPolicy.hpp" +#include "../../src/spider/scheduler/SchedulerMessage.hpp" +#include "../../src/spider/scheduler/SchedulerPolicy.hpp" +#include "../../src/spider/scheduler/SchedulerServer.hpp" +#include "../../src/spider/storage/DataStorage.hpp" +#include "../../src/spider/storage/MetadataStorage.hpp" +#include "../storage/StorageTestHelper.hpp" + +namespace { +TEMPLATE_LIST_TEST_CASE( + "Scheduler server test", + "[scheduler][server][storage]", + spider::test::StorageTypeList +) { + std::tuple< + std::unique_ptr, + std::unique_ptr> + storages = spider::test::create_storage< + std::tuple_element_t<0, TestType>, + std::tuple_element_t<1, TestType>>(); + std::shared_ptr const metadata_store + = std::move(std::get<0>(storages)); + std::shared_ptr const data_store = std::move(std::get<1>(storages)); + + std::shared_ptr const policy + = std::make_shared(); + + constexpr unsigned short cPort = 6021; + spider::scheduler::SchedulerServer server{cPort, policy, metadata_store, data_store}; + + // Start server in another thread + std::thread thread{[&]() { server.run(); }}; + + // Create client socket + boost::asio::io_context context; + boost::asio::ip::tcp::endpoint const endpoint{boost::asio::ip::tcp::v4(), cPort}; + boost::asio::ip::tcp::socket socket{context}; + boost::asio::connect(socket, std::vector{endpoint}); + + // Add task to storage + spider::core::Task const parent_task{"parent"}; + spider::core::Task const child_task{"child"}; + spider::core::TaskGraph graph; + graph.add_task(parent_task); + graph.add_task(child_task); + graph.add_dependency(parent_task.get_id(), child_task.get_id()); + boost::uuids::random_generator gen; + boost::uuids::uuid const job_id = gen(); + REQUIRE(metadata_store->add_job(job_id, gen(), graph).success()); + + // Schedule request should succeed + spider::scheduler::ScheduleTaskRequest const req{gen(), ""}; + msgpack::sbuffer req_buffer; + msgpack::pack(req_buffer, req); + REQUIRE(spider::core::send_message(socket, req_buffer)); + + // Get response should succeed and get child task + std::optional const& res_buffer = spider::core::receive_message(socket); + REQUIRE(metadata_store->remove_job(job_id).success()); + REQUIRE(res_buffer.has_value()); + if (res_buffer.has_value()) { + msgpack::object_handle const handle + = msgpack::unpack(res_buffer.value().data(), res_buffer.value().size()); + msgpack::object const object = handle.get(); + spider::scheduler::ScheduleTaskResponse const res + = object.as(); + REQUIRE(res.has_task_id()); + REQUIRE(res.get_task_id() == parent_task.get_id()); + } + socket.close(); + server.stop(); + thread.join(); +} +} // namespace + +// NOLINTEND(cert-err58-cpp,cppcoreguidelines-avoid-do-while,readability-function-cognitive-complexity,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,clang-analyzer-optin.core.EnumCastOutOfRange) diff --git a/tests/storage/test-DataStorage.cpp b/tests/storage/test-DataStorage.cpp index 6e0e548..5841e66 100644 --- a/tests/storage/test-DataStorage.cpp +++ b/tests/storage/test-DataStorage.cpp @@ -91,7 +91,8 @@ TEMPLATE_LIST_TEST_CASE( spider::core::Task const task{"func"}; spider::core::TaskGraph graph; graph.add_task(task); - REQUIRE(metadata_storage->add_job(gen(), gen(), graph).success()); + boost::uuids::uuid const job_id = gen(); + REQUIRE(metadata_storage->add_job(job_id, gen(), graph).success()); // Add task reference without data should fail. REQUIRE(!data_storage->add_task_reference(gen(), task.get_id()).success()); @@ -105,6 +106,9 @@ TEMPLATE_LIST_TEST_CASE( // Remove task reference REQUIRE(data_storage->remove_task_reference(data.get_id(), task.get_id()).success()); + + // Remove job + REQUIRE(metadata_storage->remove_job(job_id).success()); } TEMPLATE_LIST_TEST_CASE( diff --git a/tests/utils/CoreTaskUtils.cpp b/tests/utils/CoreTaskUtils.cpp index 5e7dc74..f0df8b2 100644 --- a/tests/utils/CoreTaskUtils.cpp +++ b/tests/utils/CoreTaskUtils.cpp @@ -141,9 +141,10 @@ auto task_equal(core::Task const& t1, core::Task const& t2) -> bool { if (t1.get_function_name() != t2.get_function_name()) { return false; } - if (t1.get_state() != t2.get_state()) { - return false; - } + // Task state might not be the same + // if (t1.get_state() != t2.get_state()) { + // return false; + // } if (!float_equal(t1.get_timeout(), t2.get_timeout())) { return false; }