Skip to content

Commit

Permalink
parse safetensors, model, tokenizer, generation configs
Browse files Browse the repository at this point in the history
  • Loading branch information
dhconnelly committed Dec 2, 2024
1 parent b277796 commit 94cd02e
Show file tree
Hide file tree
Showing 12 changed files with 210 additions and 28 deletions.
15 changes: 3 additions & 12 deletions src/http/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,27 +125,18 @@ void MustSend(ResponseWriter& resp, StatusCode status) noexcept {
}
}

OwnedFd to_owned(int fd) {
return OwnedFd(new int(fd), [](int* fdp) {
if (fdp && *fdp >= 0) {
close(*fdp);
delete fdp;
}
});
}

OwnedFd ServerSocket() {
int fd = socket(AF_INET, SOCK_STREAM, 0);
if (fd < 0) throw SystemError(errno);
int reuse = 1;
setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse));
return to_owned(fd);
return Own(fd);
}

std::array<OwnedFd, 2> MakePipe() {
int fds[2];
if (::pipe(fds) < 0) throw SystemError(errno);
return {to_owned(fds[0]), to_owned(fds[1])};
return {Own(fds[0]), Own(fds[1])};
}

class SocketWriter : public ResponseWriter {
Expand Down Expand Up @@ -340,7 +331,7 @@ void HttpServer::Accept() {
}
char ip[INET_ADDRSTRLEN];
inet_ntop(AF_INET, &client_addr.sin_addr, ip, INET_ADDRSTRLEN);
HttpServer::Client client{to_owned(client_fd), ntohs(client_addr.sin_port),
HttpServer::Client client{Own(client_fd), ntohs(client_addr.sin_port),
std::string(ip)};

// hand it off to a worker thread
Expand Down
3 changes: 1 addition & 2 deletions src/http/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "http/thread_pool.h"
#include "http/types.h"
#include "utils/pointers.h"

namespace gabby {
namespace http {
Expand All @@ -23,8 +24,6 @@ struct ServerConfig {

std::ostream& operator<<(std::ostream& os, const ServerConfig& config);

using OwnedFd = std::unique_ptr<int, void (*)(int*)>;

class HttpServer {
public:
explicit HttpServer(const ServerConfig& config);
Expand Down
22 changes: 15 additions & 7 deletions src/inference/generator.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#include "inference/generator.h"

#include <cerrno>
#include <cstdio>
#include <format>
#include <sstream>

#include "json/parser.h"
#include "utils/logging.h"

namespace gabby {
Expand All @@ -27,20 +30,25 @@ std::ostream& operator<<(std::ostream& os, const Request& msg) {
to_string(msg.user_message));
}

Message Generator::Generate(const Request& req) {
Message Llama3Generator::Generate(const Request& req) {
return Message{
.role = "assistant",
.content = "hey this is gabby, how are u",
};
}

/* static */ std::unique_ptr<Generator> Generator::LoadFromDirectory(
/* static */
std::unique_ptr<Generator> Llama3Generator::LoadFromDirectory(
std::filesystem::path dir) {
fs::directory_iterator contents(dir);
for (fs::path file : contents) {
LOG(DEBUG) << "scanning: " << file.string();
}
return std::make_unique<Generator>();
auto config = json::ParseFile(dir / "config.json");
auto gen_config = json::ParseFile(dir / "generation_config.json");
auto special_tokens_map = json::ParseFile(dir / "special_tokens_map.json");
auto tok_config = json::ParseFile(dir / "tokenizer_config.json");
auto tok = json::ParseFile(dir / "tokenizer.json");
auto tensors = Safetensors::LoadFile(dir / "model.safetensors");
return std::unique_ptr<Generator>(
new Llama3Generator(config, gen_config, special_tokens_map, tok_config,
tok, std::move(tensors)));
}

} // namespace inference
Expand Down
30 changes: 28 additions & 2 deletions src/inference/generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
#include <ostream>
#include <string>

#include "inference/safetensors.h"
#include "json/json.h"

namespace gabby {
namespace inference {

Expand All @@ -23,13 +26,36 @@ struct Request {

std::ostream& operator<<(std::ostream& os, const Request& msg);

// TODO: thread-safe
class Generator {
public:
virtual Message Generate(const Request& req);
virtual ~Generator() = default;
virtual Message Generate(const Request& req) = 0;
};

class Llama3Generator : public Generator {
public:
Message Generate(const Request& req) override;

static std::unique_ptr<Generator> LoadFromDirectory(
std::filesystem::path dir);

private:
Llama3Generator(json::ValuePtr config, json::ValuePtr gen_config,
json::ValuePtr special_tokens_map,
json::ValuePtr tok_config, json::ValuePtr tok,
Safetensors tensors)
: config_(config),
gen_config_(gen_config),
special_tokens_map_(special_tokens_map),
tok_config_(tok_config),
tok_(tok),
tensors_(std::move(tensors)) {}
json::ValuePtr config_;
json::ValuePtr gen_config_;
json::ValuePtr special_tokens_map_;
json::ValuePtr tok_config_;
json::ValuePtr tok_;
Safetensors tensors_;
};

} // namespace inference
Expand Down
39 changes: 39 additions & 0 deletions src/inference/safetensors.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "inference/safetensors.h"

#include <fcntl.h>

#include <cstdint>
#include <cstdio>
#include <memory>

#include "json/json.h"
#include "json/parser.h"
#include "utils/logging.h"

namespace gabby {
namespace inference {

/* static */
Safetensors Safetensors::LoadFile(const std::filesystem::path& path) {
// format: https://github.com/huggingface/safetensors
size_t file_size = std::filesystem::file_size(path);
OwnedFd fd = Open(path.c_str(), O_RDONLY);
OwnedMmap data = Mmap(file_size, std::move(fd));

// get header size
uint64_t header_size = 0;
for (int i = 0; i < 8; i++) {
header_size |= data.get()[i] << (8 * i);
}
LOG(DEBUG) << "header size: " << header_size;

// get header
std::string s(data.get() + 8, data.get() + 8 + header_size);
json::ValuePtr header = json::Parse(s);
LOG(DEBUG) << "header: " << *header;

return Safetensors(std::move(data), header, 8 + header_size);
}

} // namespace inference
} // namespace gabby
30 changes: 30 additions & 0 deletions src/inference/safetensors.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef GABBY_INFERENCE_SAFETENSORS_H_
#define GABBY_INFERENCE_SAFETENSORS_H_

#include <filesystem>
#include <stdexcept>

#include "json/json.h"
#include "utils/pointers.h"

namespace gabby {
namespace inference {

class Safetensors {
public:
static Safetensors LoadFile(const std::filesystem::path& path);
const json::ValuePtr header() const { return header_; }

private:
Safetensors(OwnedMmap mem, json::ValuePtr header, size_t data_offset)
: mem_(std::move(mem)), header_(header), data_offset_(data_offset) {}

OwnedMmap mem_;
json::ValuePtr header_;
size_t data_offset_ = 0;
};

} // namespace inference
} // namespace gabby

#endif // GABBY_INFERENCE_SAFETENSORS_H_
14 changes: 13 additions & 1 deletion src/json/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <format>

#include "utils/logging.h"
#include "utils/pointers.h"

namespace gabby {
namespace json {
Expand Down Expand Up @@ -107,9 +108,15 @@ std::optional<Token> Scanner::Scan() {
case '"': {
Advance();
std::string s;
bool escaped = false;
while (true) {
c = Advance();
if (c == '"' || c == '\n') break;
if ((c == '"' && !escaped) || c == '\n') break;
if (c == '\\' && !escaped) {
escaped = true;
continue;
}
escaped = false;
s.push_back(c);
}
if (c != '"') throw ParsingError("unterminated string");
Expand Down Expand Up @@ -268,5 +275,10 @@ ValuePtr Parse(const std::string& s) {
return Parse(f.get(), s.size());
}

ValuePtr ParseFile(const std::filesystem::path& path) {
OwnedStream f = Fopen(path.c_str(), "r");
return Parse(f.get(), std::filesystem::file_size(path));
}

} // namespace json
} // namespace gabby
2 changes: 2 additions & 0 deletions src/json/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define GABBY_JSON_PARSER_H_

#include <cstdio>
#include <filesystem>
#include <memory>
#include <optional>
#include <string>
Expand Down Expand Up @@ -66,6 +67,7 @@ class Parser {
std::optional<Token> lookahead_;
};

ValuePtr ParseFile(const std::filesystem::path& path);
ValuePtr Parse(FILE* f, int size);
ValuePtr Parse(const std::string& s);

Expand Down
8 changes: 5 additions & 3 deletions src/json/parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
namespace gabby {
namespace json {

TEST(JSON, ParseNull) {
ScopedLogLevel scope(LogLevel::DEBUG);
EXPECT_EQ(*Value::Nil(), *Parse("null"));
TEST(JSON, ParseEscapes) {
EXPECT_EQ(*Value::String(R"(""")"), *Parse(R"("\"\"\"")"));
EXPECT_EQ(*Value::String(R"(\\")"), *Parse(R"("\\\\\"")"));
}

TEST(JSON, ParseNull) { EXPECT_EQ(*Value::Nil(), *Parse("null")); }

TEST(JSON, ParseNumber) {
EXPECT_EQ(*Value::Number(0), *Parse("0"));
EXPECT_EQ(*Value::Number(17), *Parse("17"));
Expand Down
3 changes: 2 additions & 1 deletion src/service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ json::ValuePtr MakeResponse(const inference::Message& answer) {
InferenceService::InferenceService(Config config)
: config_(config),
server_(std::make_unique<http::HttpServer>(config_.server_config)),
generator_(inference::Generator::LoadFromDirectory(config.model_dir)) {}
generator_(
inference::Llama3Generator::LoadFromDirectory(config.model_dir)) {}

InferenceService::InferenceService(
std::unique_ptr<http::HttpServer> server,
Expand Down
40 changes: 40 additions & 0 deletions src/utils/pointers.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#include "utils/pointers.h"

#include <fcntl.h>
#include <sys/mman.h>

#include <cerrno>

#include "utils/logging.h"

namespace gabby {

OwnedStream Fopen(const char *name, const char *mode) {
std::unique_ptr<FILE, decltype(&fclose)> f(fopen(name, mode), fclose);
if (f.get() == nullptr) throw SystemError(errno);
return std::move(f);
}

OwnedFd Own(int fd) {
return OwnedFd(new int(fd), [](int *fdp) {
if (fdp && *fdp >= 0) {
close(*fdp);
delete fdp;
}
});
}

OwnedFd Open(const char *name, int flags) {
int fd = open(name, flags);
if (fd < 0) throw SystemError(errno);
return Own(fd);
}

void MmapDeleter::operator()(uint8_t *p) { munmap(p, size); }

OwnedMmap Mmap(size_t size, OwnedFd fd) {
void *data = mmap(nullptr, size, PROT_READ, MAP_PRIVATE, *fd.release(), 0);
return OwnedMmap(static_cast<uint8_t *>(data), MmapDeleter{.size = size});
}

} // namespace gabby
32 changes: 32 additions & 0 deletions src/utils/pointers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#ifndef GABBY_UTILS_POINTERS_H_
#define GABBY_UTILS_POINTERS_H_

#include <cstdint>
#include <cstdio>
#include <memory>
#include <string_view>

namespace gabby {

using OwnedStream = std::unique_ptr<FILE, decltype(&fclose)>;

OwnedStream Fopen(const char *name, const char *mode);

using OwnedFd = std::unique_ptr<int, void (*)(int *)>;

OwnedFd Own(int fd);

OwnedFd Open(const char *name, int flags);

struct MmapDeleter {
size_t size;
void operator()(uint8_t *p);
};

using OwnedMmap = std::unique_ptr<uint8_t, MmapDeleter>;

OwnedMmap Mmap(size_t size, OwnedFd fd);

} // namespace gabby

#endif // GABBY_UTILS_POINTERS_H_

0 comments on commit 94cd02e

Please sign in to comment.