From 94cd02e4be639d048b52b7307597dd4e4b4c0079 Mon Sep 17 00:00:00 2001 From: Daniel Connelly Date: Mon, 2 Dec 2024 22:19:34 +0100 Subject: [PATCH] parse safetensors, model, tokenizer, generation configs --- src/http/server.cc | 15 +++----------- src/http/server.h | 3 +-- src/inference/generator.cc | 22 +++++++++++++------- src/inference/generator.h | 30 +++++++++++++++++++++++++-- src/inference/safetensors.cc | 39 +++++++++++++++++++++++++++++++++++ src/inference/safetensors.h | 30 +++++++++++++++++++++++++++ src/json/parser.cc | 14 ++++++++++++- src/json/parser.h | 2 ++ src/json/parser_test.cc | 8 +++++--- src/service.cc | 3 ++- src/utils/pointers.cc | 40 ++++++++++++++++++++++++++++++++++++ src/utils/pointers.h | 32 +++++++++++++++++++++++++++++ 12 files changed, 210 insertions(+), 28 deletions(-) create mode 100644 src/inference/safetensors.cc create mode 100644 src/inference/safetensors.h create mode 100644 src/utils/pointers.cc create mode 100644 src/utils/pointers.h diff --git a/src/http/server.cc b/src/http/server.cc index 96ab226..e1759d6 100644 --- a/src/http/server.cc +++ b/src/http/server.cc @@ -125,27 +125,18 @@ void MustSend(ResponseWriter& resp, StatusCode status) noexcept { } } -OwnedFd to_owned(int fd) { - return OwnedFd(new int(fd), [](int* fdp) { - if (fdp && *fdp >= 0) { - close(*fdp); - delete fdp; - } - }); -} - OwnedFd ServerSocket() { int fd = socket(AF_INET, SOCK_STREAM, 0); if (fd < 0) throw SystemError(errno); int reuse = 1; setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)); - return to_owned(fd); + return Own(fd); } std::array MakePipe() { int fds[2]; if (::pipe(fds) < 0) throw SystemError(errno); - return {to_owned(fds[0]), to_owned(fds[1])}; + return {Own(fds[0]), Own(fds[1])}; } class SocketWriter : public ResponseWriter { @@ -340,7 +331,7 @@ void HttpServer::Accept() { } char ip[INET_ADDRSTRLEN]; inet_ntop(AF_INET, &client_addr.sin_addr, ip, INET_ADDRSTRLEN); - HttpServer::Client client{to_owned(client_fd), 
ntohs(client_addr.sin_port), + HttpServer::Client client{Own(client_fd), ntohs(client_addr.sin_port), std::string(ip)}; // hand it off to a worker thread diff --git a/src/http/server.h b/src/http/server.h index 7637e62..bbca678 100644 --- a/src/http/server.h +++ b/src/http/server.h @@ -9,6 +9,7 @@ #include "http/thread_pool.h" #include "http/types.h" +#include "utils/pointers.h" namespace gabby { namespace http { @@ -23,8 +24,6 @@ struct ServerConfig { std::ostream& operator<<(std::ostream& os, const ServerConfig& config); -using OwnedFd = std::unique_ptr; - class HttpServer { public: explicit HttpServer(const ServerConfig& config); diff --git a/src/inference/generator.cc b/src/inference/generator.cc index 7cb009e..3bbd590 100644 --- a/src/inference/generator.cc +++ b/src/inference/generator.cc @@ -1,8 +1,11 @@ #include "inference/generator.h" +#include +#include #include #include +#include "json/parser.h" #include "utils/logging.h" namespace gabby { @@ -27,20 +30,25 @@ std::ostream& operator<<(std::ostream& os, const Request& msg) { to_string(msg.user_message)); } -Message Generator::Generate(const Request& req) { +Message Llama3Generator::Generate(const Request& req) { return Message{ .role = "assistant", .content = "hey this is gabby, how are u", }; } -/* static */ std::unique_ptr Generator::LoadFromDirectory( +/* static */ +std::unique_ptr Llama3Generator::LoadFromDirectory( std::filesystem::path dir) { - fs::directory_iterator contents(dir); - for (fs::path file : contents) { - LOG(DEBUG) << "scanning: " << file.string(); - } - return std::make_unique(); + auto config = json::ParseFile(dir / "config.json"); + auto gen_config = json::ParseFile(dir / "generation_config.json"); + auto special_tokens_map = json::ParseFile(dir / "special_tokens_map.json"); + auto tok_config = json::ParseFile(dir / "tokenizer_config.json"); + auto tok = json::ParseFile(dir / "tokenizer.json"); + auto tensors = Safetensors::LoadFile(dir / "model.safetensors"); + return 
std::unique_ptr( + new Llama3Generator(config, gen_config, special_tokens_map, tok_config, + tok, std::move(tensors))); } } // namespace inference diff --git a/src/inference/generator.h b/src/inference/generator.h index 06c4936..e26bb14 100644 --- a/src/inference/generator.h +++ b/src/inference/generator.h @@ -6,6 +6,9 @@ #include #include +#include "inference/safetensors.h" +#include "json/json.h" + namespace gabby { namespace inference { @@ -23,13 +26,36 @@ struct Request { std::ostream& operator<<(std::ostream& os, const Request& msg); -// TODO: thread-safe class Generator { public: - virtual Message Generate(const Request& req); + virtual ~Generator() = default; + virtual Message Generate(const Request& req) = 0; +}; + +class Llama3Generator : public Generator { +public: + Message Generate(const Request& req) override; static std::unique_ptr LoadFromDirectory( std::filesystem::path dir); + +private: + Llama3Generator(json::ValuePtr config, json::ValuePtr gen_config, + json::ValuePtr special_tokens_map, + json::ValuePtr tok_config, json::ValuePtr tok, + Safetensors tensors) + : config_(config), + gen_config_(gen_config), + special_tokens_map_(special_tokens_map), + tok_config_(tok_config), + tok_(tok), + tensors_(std::move(tensors)) {} + json::ValuePtr config_; + json::ValuePtr gen_config_; + json::ValuePtr special_tokens_map_; + json::ValuePtr tok_config_; + json::ValuePtr tok_; + Safetensors tensors_; }; } // namespace inference diff --git a/src/inference/safetensors.cc b/src/inference/safetensors.cc new file mode 100644 index 0000000..1eff3ad --- /dev/null +++ b/src/inference/safetensors.cc @@ -0,0 +1,39 @@ +#include "inference/safetensors.h" + +#include + +#include +#include +#include + +#include "json/json.h" +#include "json/parser.h" +#include "utils/logging.h" + +namespace gabby { +namespace inference { + +/* static */ +Safetensors Safetensors::LoadFile(const std::filesystem::path& path) { + // format: https://github.com/huggingface/safetensors + size_t 
file_size = std::filesystem::file_size(path); + OwnedFd fd = Open(path.c_str(), O_RDONLY); + OwnedMmap data = Mmap(file_size, std::move(fd)); + + // get header size + uint64_t header_size = 0; + for (int i = 0; i < 8; i++) { + header_size |= static_cast<uint64_t>(data.get()[i]) << (8 * i); + } + LOG(DEBUG) << "header size: " << header_size; + + // get header + std::string s(data.get() + 8, data.get() + 8 + header_size); + json::ValuePtr header = json::Parse(s); + LOG(DEBUG) << "header: " << *header; + + return Safetensors(std::move(data), header, 8 + header_size); +} + +} // namespace inference +} // namespace gabby diff --git a/src/inference/safetensors.h b/src/inference/safetensors.h new file mode 100644 index 0000000..096ed88 --- /dev/null +++ b/src/inference/safetensors.h @@ -0,0 +1,30 @@ +#ifndef GABBY_INFERENCE_SAFETENSORS_H_ +#define GABBY_INFERENCE_SAFETENSORS_H_ + +#include +#include + +#include "json/json.h" +#include "utils/pointers.h" + +namespace gabby { +namespace inference { + +class Safetensors { +public: + static Safetensors LoadFile(const std::filesystem::path& path); + const json::ValuePtr header() const { return header_; } + +private: + Safetensors(OwnedMmap mem, json::ValuePtr header, size_t data_offset) + : mem_(std::move(mem)), header_(header), data_offset_(data_offset) {} + + OwnedMmap mem_; + json::ValuePtr header_; + size_t data_offset_ = 0; +}; + +} // namespace inference +} // namespace gabby + +#endif // GABBY_INFERENCE_SAFETENSORS_H_ diff --git a/src/json/parser.cc b/src/json/parser.cc index 1248540..6e4ec62 100644 --- a/src/json/parser.cc +++ b/src/json/parser.cc @@ -6,6 +6,7 @@ #include #include "utils/logging.h" +#include "utils/pointers.h" namespace gabby { namespace json { @@ -107,9 +108,15 @@ std::optional Scanner::Scan() { case '"': { Advance(); std::string s; + bool escaped = false; while (true) { c = Advance(); - if (c == '"' || c == '\n') break; + if ((c == '"' && !escaped) || c == '\n') break; + if (c == '\\' && !escaped) { + escaped = true; + 
continue; + } + escaped = false; s.push_back(c); } if (c != '"') throw ParsingError("unterminated string"); @@ -268,5 +275,10 @@ ValuePtr Parse(const std::string& s) { return Parse(f.get(), s.size()); } +ValuePtr ParseFile(const std::filesystem::path& path) { + OwnedStream f = Fopen(path.c_str(), "r"); + return Parse(f.get(), std::filesystem::file_size(path)); +} + } // namespace json } // namespace gabby diff --git a/src/json/parser.h b/src/json/parser.h index 8277180..59d6611 100644 --- a/src/json/parser.h +++ b/src/json/parser.h @@ -2,6 +2,7 @@ #define GABBY_JSON_PARSER_H_ #include +#include #include #include #include @@ -66,6 +67,7 @@ class Parser { std::optional lookahead_; }; +ValuePtr ParseFile(const std::filesystem::path& path); ValuePtr Parse(FILE* f, int size); ValuePtr Parse(const std::string& s); diff --git a/src/json/parser_test.cc b/src/json/parser_test.cc index 668cac8..d47cd6c 100644 --- a/src/json/parser_test.cc +++ b/src/json/parser_test.cc @@ -6,11 +6,13 @@ namespace gabby { namespace json { -TEST(JSON, ParseNull) { - ScopedLogLevel scope(LogLevel::DEBUG); - EXPECT_EQ(*Value::Nil(), *Parse("null")); +TEST(JSON, ParseEscapes) { + EXPECT_EQ(*Value::String(R"(""")"), *Parse(R"("\"\"\"")")); + EXPECT_EQ(*Value::String(R"(\\")"), *Parse(R"("\\\\\"")")); } +TEST(JSON, ParseNull) { EXPECT_EQ(*Value::Nil(), *Parse("null")); } + TEST(JSON, ParseNumber) { EXPECT_EQ(*Value::Number(0), *Parse("0")); EXPECT_EQ(*Value::Number(17), *Parse("17")); diff --git a/src/service.cc b/src/service.cc index 1c5a894..3edd3da 100644 --- a/src/service.cc +++ b/src/service.cc @@ -119,7 +119,8 @@ json::ValuePtr MakeResponse(const inference::Message& answer) { InferenceService::InferenceService(Config config) : config_(config), server_(std::make_unique(config_.server_config)), - generator_(inference::Generator::LoadFromDirectory(config.model_dir)) {} + generator_( + inference::Llama3Generator::LoadFromDirectory(config.model_dir)) {} InferenceService::InferenceService( 
std::unique_ptr server, diff --git a/src/utils/pointers.cc b/src/utils/pointers.cc new file mode 100644 index 0000000..0459453 --- /dev/null +++ b/src/utils/pointers.cc @@ -0,0 +1,41 @@ +#include "utils/pointers.h" + +#include +#include + +#include + +#include "utils/logging.h" + +namespace gabby { + +OwnedStream Fopen(const char *name, const char *mode) { + std::unique_ptr f(fopen(name, mode), fclose); + if (f.get() == nullptr) throw SystemError(errno); + return std::move(f); +} + +OwnedFd Own(int fd) { + return OwnedFd(new int(fd), [](int *fdp) { + if (fdp && *fdp >= 0) { + close(*fdp); + delete fdp; + } + }); +} + +OwnedFd Open(const char *name, int flags) { + int fd = open(name, flags); + if (fd < 0) throw SystemError(errno); + return Own(fd); +} + +void MmapDeleter::operator()(uint8_t *p) { munmap(p, size); } + +OwnedMmap Mmap(size_t size, OwnedFd fd) { + void *data = mmap(nullptr, size, PROT_READ, MAP_PRIVATE, *fd, 0); + if (data == MAP_FAILED) throw SystemError(errno); + return OwnedMmap(static_cast<uint8_t *>(data), MmapDeleter{.size = size}); +} + +} // namespace gabby diff --git a/src/utils/pointers.h b/src/utils/pointers.h new file mode 100644 index 0000000..bfc83fe --- /dev/null +++ b/src/utils/pointers.h @@ -0,0 +1,32 @@ +#ifndef GABBY_UTILS_POINTERS_H_ +#define GABBY_UTILS_POINTERS_H_ + +#include +#include +#include +#include + +namespace gabby { + +using OwnedStream = std::unique_ptr; + +OwnedStream Fopen(const char *name, const char *mode); + +using OwnedFd = std::unique_ptr; + +OwnedFd Own(int fd); + +OwnedFd Open(const char *name, int flags); + +struct MmapDeleter { + size_t size; + void operator()(uint8_t *p); +}; + +using OwnedMmap = std::unique_ptr; + +OwnedMmap Mmap(size_t size, OwnedFd fd); + +} // namespace gabby + +#endif // GABBY_UTILS_POINTERS_H_