Skip to content

Commit

Permalink
feat: Add TaskContext and Data as task argument (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
sitaowang1998 authored Dec 21, 2024
1 parent 0a896a8 commit 03d043a
Show file tree
Hide file tree
Showing 20 changed files with 649 additions and 60 deletions.
9 changes: 8 additions & 1 deletion src/spider/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ set(SPIDER_CORE_SOURCES
set(SPIDER_CORE_HEADERS
core/Error.hpp
core/Data.hpp
core/DataImpl.hpp
core/Driver.hpp
core/KeyValueData.hpp
core/Task.hpp
core/TaskContextImpl.hpp
core/TaskGraph.hpp
core/JobMetadata.hpp
io/BoostAsio.hpp
Expand Down Expand Up @@ -83,7 +85,12 @@ set(SPIDER_TASK_EXECUTOR_SOURCES

add_executable(spider_task_executor)
target_sources(spider_task_executor PRIVATE ${SPIDER_TASK_EXECUTOR_SOURCES})
target_link_libraries(spider_task_executor PRIVATE spider_core)
target_link_libraries(
spider_task_executor
PRIVATE
spider_core
spider_client_lib
)
target_link_libraries(
spider_task_executor
PRIVATE
Expand Down
99 changes: 93 additions & 6 deletions src/spider/client/Data.hpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
#ifndef SPIDER_CLIENT_DATA_HPP
#define SPIDER_CLIENT_DATA_HPP

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <boost/uuid/uuid.hpp>

#include "../core/Error.hpp"
#include "../io/MsgPack.hpp" // IWYU pragma: keep
#include "../io/Serializer.hpp"
#include "../storage/DataStorage.hpp"
#include "Exception.hpp"

namespace spider {

namespace core {
class Data;
class DataStorage;
class DataImpl;
} // namespace core

/**
* A representation of data stored on external storage. This class allows the user to define:
Expand All @@ -32,16 +45,24 @@ class Data {
/**
* @return The stored value.
*/
auto get() -> T;
auto get() -> T {
std::string const& value = m_impl->get_value();
return msgpack::unpack(value.data(), value.size()).get().as<T>();
}

/**
* Sets the data's locality, indicated by the nodes that contain the data.
*
* @param nodes
* @param hard Whether the data is only accessible from the given nodes (i.e., the locality is a
* hard requirement).
* @throw spider::ConnectionException
*/
void set_locality(std::vector<std::string> const& nodes, bool hard);
void set_locality(std::vector<std::string> const& nodes, bool hard) {
m_impl->set_locality(nodes);
m_impl->set_hard_locality(hard);
m_data_store->set_data_locality(*m_impl);
}

class Builder {
public:
Expand All @@ -53,7 +74,11 @@ class Data {
* is a hard requirement.
* @return self
*/
auto set_locality(std::vector<std::string> const& nodes, bool hard) -> Builder&;
auto set_locality(std::vector<std::string> const& nodes, bool hard) -> Builder& {
m_nodes = nodes;
m_hard_locality = hard;
return *this;
}

/**
* Sets the cleanup function for the data. This function will be called when the data is no
Expand All @@ -62,7 +87,10 @@ class Data {
* @param f
* @return self
*/
auto set_cleanup_func(std::function<void(T const&)> const& f) -> Builder&;
auto set_cleanup_func(std::function<void(T const&)> const& f) -> Builder& {
m_cleanup_func = f;
return *this;
}

/**
* Builds the data object.
Expand All @@ -71,11 +99,70 @@ class Data {
* @return The built object.
* @throw spider::ConnectionException
*/
auto build(T const& t) -> Data;
auto build(T const& t) -> Data {
msgpack::sbuffer buffer;
msgpack::pack(buffer, t);
auto data = std::make_unique<core::Data>(std::string{buffer.data(), buffer.size()});
data->set_locality(m_nodes);
data->set_hard_locality(m_hard_locality);
core::StorageErr err;
switch (m_data_source) {
case DataSource::Driver:
err = m_data_store->add_driver_data(m_source_id, *data);
if (!err.success()) {
throw ConnectionException(err.description);
}
break;
case DataSource::TaskContext:
err = m_data_store->add_task_data(m_source_id, *data);
if (!err.success()) {
throw ConnectionException(err.description);
}
break;
}
return Data{data, m_data_store};
}

private:
enum class DataSource : std::uint8_t {
Driver,
TaskContext
};

explicit Builder(
std::shared_ptr<core::DataStorage> data_store,
boost::uuids::uuid const source_id,
DataSource const data_source
)
: m_data_store{std::move(data_store)},
m_source_id{source_id},
m_data_source{data_source} {}

std::vector<std::string> m_nodes;
bool m_hard_locality = false;
std::function<void(T const&)> m_cleanup_func;

std::shared_ptr<core::DataStorage> m_data_store;
boost::uuids::uuid m_source_id;
DataSource m_data_source;

friend class Driver;
friend class TaskContext;
};

Data() = default;

private:
std::unique_ptr<DataImpl> m_impl;
Data(std::unique_ptr<core::Data> impl, std::shared_ptr<core::DataStorage> data_store)
: m_impl{std::move(impl)},
m_data_store{std::move(data_store)} {}

[[nodiscard]] auto get_impl() const -> std::unique_ptr<core::Data> const& { return m_impl; }

std::unique_ptr<core::Data> m_impl;
std::shared_ptr<core::DataStorage> m_data_store;

friend class core::DataImpl;
};
} // namespace spider

Expand Down
8 changes: 8 additions & 0 deletions src/spider/client/Driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

#include <boost/uuid/uuid.hpp>

#include "../io/Serializer.hpp"
#include "../worker/FunctionManager.hpp"
#include "Data.hpp"
#include "Job.hpp"
#include "task.hpp"
#include "TaskGraph.hpp"
Expand Down Expand Up @@ -52,6 +54,12 @@ class Driver {
*/
Driver(std::string const& storage_url, boost::uuids::uuid id);

/**
* @return Data builder.
*/
template <Serializable T>
auto get_data_builder() -> Data<T>::Builder;

/**
* Inserts the given key-value pair into the key-value store, overwriting any existing value.
*
Expand Down
35 changes: 35 additions & 0 deletions src/spider/client/TaskContext.hpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
#ifndef SPIDER_CLIENT_TASKCONTEXT_HPP
#define SPIDER_CLIENT_TASKCONTEXT_HPP

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include <boost/uuid/uuid.hpp>

#include "../io/Serializer.hpp"
#include "Data.hpp"
#include "Job.hpp"
#include "task.hpp"
#include "TaskGraph.hpp"

namespace spider {
namespace core {
class DataStorage;
class MetadataStorage;
class TaskContextImpl;
} // namespace core

/**
* TaskContext provides a task with all Spider functionalities, e.g. getting task instance id,
* accessing data storage, creating and waiting for new jobs, etc.
Expand All @@ -32,6 +42,12 @@ class TaskContext {
*/
[[nodiscard]] auto get_id() const -> boost::uuids::uuid;

/**
* @return Data builder.
*/
template <Serializable T>
auto get_data_builder() -> Data<T>::Builder;

/**
* Inserts the given key-value pair into the key-value store, overwriting any existing value.
*
Expand Down Expand Up @@ -111,6 +127,25 @@ class TaskContext {
* @throw spider::ConnectionException
*/
auto get_jobs() -> std::vector<boost::uuids::uuid>;

TaskContext() = default;

private:
TaskContext(
std::shared_ptr<core::DataStorage> data_store,
std::shared_ptr<core::MetadataStorage> metadata_store
)
: m_data_store{std::move(data_store)},
m_metadata_store{std::move(metadata_store)} {}

auto get_data_store() -> std::shared_ptr<core::DataStorage> { return m_data_store; }

auto get_metadata_store() -> std::shared_ptr<core::MetadataStorage> { return m_metadata_store; }

std::shared_ptr<core::DataStorage> m_data_store;
std::shared_ptr<core::MetadataStorage> m_metadata_store;

friend class core::TaskContextImpl;
};
} // namespace spider

Expand Down
4 changes: 1 addition & 3 deletions src/spider/core/Data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@ class Data {

Data(boost::uuids::uuid const id, std::string value) : m_id(id), m_value(std::move(value)) {}

static auto is_data() -> bool { return true; }

[[nodiscard]] auto get_id() const -> boost::uuids::uuid { return m_id; }

[[nodiscard]] auto get_value() const -> std::string { return m_value; }
[[nodiscard]] auto get_value() const -> std::string const& { return m_value; }

[[nodiscard]] auto get_locality() const -> std::vector<std::string> const& {
return m_locality;
Expand Down
28 changes: 28 additions & 0 deletions src/spider/core/DataImpl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#ifndef SPIDER_CORE_DATAIMPL_HPP
#define SPIDER_CORE_DATAIMPL_HPP

#include <memory>
#include <utility>

#include "../client/Data.hpp"
#include "../core/Data.hpp"

namespace spider::core {

class DataImpl {
public:
template <class T>
static auto create_data(std::unique_ptr<Data> data, std::shared_ptr<DataStorage> data_store)
-> spider::Data<T> {
return spider::Data<T>{std::move(data), data_store};
}

template <class T>
static auto get_impl(spider::Data<T> const& data) -> std::shared_ptr<DataStorage> {
return data.get_impl();
}
};

} // namespace spider::core

#endif
32 changes: 32 additions & 0 deletions src/spider/core/TaskContextImpl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#ifndef SPIDER_CORE_TASKCONTEXTIMPL_HPP
#define SPIDER_CORE_TASKCONTEXTIMPL_HPP

#include <memory>

#include "../client/TaskContext.hpp"
#include "../storage/DataStorage.hpp"
#include "../storage/MetadataStorage.hpp"

namespace spider::core {
class TaskContextImpl {
public:
static auto create_task_context(
std::shared_ptr<DataStorage> const& data_storage,
std::shared_ptr<MetadataStorage> const& metadata_storage
) -> TaskContext {
return TaskContext{data_storage, metadata_storage};
}

static auto get_data_store(TaskContext const& task_context) -> std::shared_ptr<DataStorage> {
return task_context.m_data_store;
}

static auto get_metadata_store(TaskContext const& task_context
) -> std::shared_ptr<MetadataStorage> {
return task_context.m_metadata_store;
}
};

} // namespace spider::core

#endif
1 change: 1 addition & 0 deletions src/spider/storage/DataStorage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class DataStorage {
virtual auto add_driver_data(boost::uuids::uuid driver_id, Data const& data) -> StorageErr = 0;
virtual auto add_task_data(boost::uuids::uuid task_id, Data const& data) -> StorageErr = 0;
virtual auto get_data(boost::uuids::uuid id, Data* data) -> StorageErr = 0;
virtual auto set_data_locality(Data const& data) -> StorageErr = 0;
virtual auto remove_data(boost::uuids::uuid id) -> StorageErr = 0;
virtual auto add_task_reference(boost::uuids::uuid id, boost::uuids::uuid task_id) -> StorageErr
= 0;
Expand Down
30 changes: 30 additions & 0 deletions src/spider/storage/MysqlStorage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,36 @@ auto MySqlDataStorage::get_data(boost::uuids::uuid id, Data* data) -> StorageErr
return StorageErr{};
}

auto MySqlDataStorage::set_data_locality(Data const& data) -> StorageErr {
try {
std::unique_ptr<sql::PreparedStatement> const delete_statement(
m_conn->prepareStatement("DELETE FROM `data_locality` WHERE `id` = ?")
);
sql::bytes id_bytes = uuid_get_bytes(data.get_id());
delete_statement->setBytes(1, &id_bytes);
delete_statement->executeUpdate();
std::unique_ptr<sql::PreparedStatement> const insert_statement(m_conn->prepareStatement(
"INSERT INTO `data_locality` (`id`, `address`) VALUES(?, ?)"
));
for (std::string const& addr : data.get_locality()) {
insert_statement->setBytes(1, &id_bytes);
insert_statement->setString(2, addr);
insert_statement->executeUpdate();
}
std::unique_ptr<sql::PreparedStatement> const hard_locality_statement(
m_conn->prepareStatement("UPDATE `data` SET `hard_locality` = ? WHERE `id` = ?")
);
hard_locality_statement->setBoolean(1, data.is_hard_locality());
hard_locality_statement->setBytes(2, &id_bytes);
hard_locality_statement->executeUpdate();
} catch (sql::SQLException& e) {
m_conn->rollback();
return StorageErr{StorageErrType::OtherErr, e.what()};
}
m_conn->commit();
return StorageErr{};
}

auto MySqlDataStorage::remove_data(boost::uuids::uuid id) -> StorageErr {
try {
std::unique_ptr<sql::PreparedStatement> statement(
Expand Down
1 change: 1 addition & 0 deletions src/spider/storage/MysqlStorage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class MySqlDataStorage : public DataStorage {
auto add_driver_data(boost::uuids::uuid driver_id, Data const& data) -> StorageErr override;
auto add_task_data(boost::uuids::uuid task_id, Data const& data) -> StorageErr override;
auto get_data(boost::uuids::uuid id, Data* data) -> StorageErr override;
auto set_data_locality(Data const& data) -> StorageErr override;
auto remove_data(boost::uuids::uuid id) -> StorageErr override;
auto
add_task_reference(boost::uuids::uuid id, boost::uuids::uuid task_id) -> StorageErr override;
Expand Down
Loading

0 comments on commit 03d043a

Please sign in to comment.