
Commit

tokenizer boilerplate
dhconnelly committed Dec 4, 2024
1 parent e271597 commit 648714b
Showing 11 changed files with 200 additions and 60 deletions.
58 changes: 58 additions & 0 deletions src/inference/config.cc
@@ -0,0 +1,58 @@
#include "inference/config.h"

#include "json/parser.h"
#include "utils/logging.h"

namespace gabby {
namespace inference {

namespace fs = std::filesystem;

std::unique_ptr<InferenceConfig> LoadConfig(const std::filesystem::path& dir) {
    LOG(DEBUG) << "loading model from: " << dir;
    auto config = json::ParseFile(dir / "config.json");
    auto gen_config = json::ParseFile(dir / "generation_config.json");
    auto special_tokens_map = json::ParseFile(dir / "special_tokens_map.json");
    auto tok_config = json::ParseFile(dir / "tokenizer_config.json");
    auto tok = json::ParseFile(dir / "tokenizer.json");
    auto tensors = Safetensors::LoadFile(dir / "model.safetensors");
    LOG(DEBUG) << "successfully loaded model";
    return std::unique_ptr<InferenceConfig>(new InferenceConfig{
        .config = config,
        .gen_config = gen_config,
        .special_tokens_map = special_tokens_map,
        .tok_config = tok_config,
        .tok = tok,
        .tensors = std::move(tensors),
    });
}

constexpr std::string_view kUserRelativeSnapshotDir =
    ".cache/huggingface/hub/models--meta-llama--Llama-3.2-1B-Instruct/"
    "snapshots";

fs::path FindDefaultModelDir() {
    const char* home = getenv("HOME");
    if (home == nullptr) throw std::runtime_error("env var HOME is unset");

    fs::path snapshots_dir(home);
    snapshots_dir /= kUserRelativeSnapshotDir;
    fs::directory_iterator contents;
    try {
        contents = fs::directory_iterator(snapshots_dir);
    } catch (fs::filesystem_error& err) {
        throw std::runtime_error(std::format("can't access model dir at {}: {}",
                                             snapshots_dir.string(),
                                             err.what()));
    }

    auto it = fs::begin(contents);
    if (it == fs::end(contents)) {
        throw std::runtime_error(
            std::format("no snapshots found in {}", snapshots_dir.string()));
    }

    return *it;
}
} // namespace inference
} // namespace gabby
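
For orientation, here is a minimal sketch of how these two entry points compose. This is a hypothetical standalone caller, not code from this commit (the real wiring lives in main.cc and service.cc below); both functions report failures as exceptions, so the caller decides whether to exit or recover.

#include <exception>
#include <iostream>

#include "inference/config.h"

int main() {
    try {
        // Pick the first snapshot directory under ~/.cache/huggingface/... and
        // load its JSON config files plus the safetensors weights.
        auto dir = gabby::inference::FindDefaultModelDir();
        auto config = gabby::inference::LoadConfig(dir);
        std::cout << "loaded model from " << dir << "\n";
    } catch (const std::exception& err) {
        std::cerr << err.what() << "\n";
        return 1;
    }
    return 0;
}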
29 changes: 29 additions & 0 deletions src/inference/config.h
@@ -0,0 +1,29 @@
#ifndef GABBY_INFERENCE_CONFIG_H_
#define GABBY_INFERENCE_CONFIG_H_

#include <filesystem>
#include <memory>

#include "inference/safetensors.h"
#include "json/json.h"

namespace gabby {
namespace inference {

struct InferenceConfig {
    json::ValuePtr config;
    json::ValuePtr gen_config;
    json::ValuePtr special_tokens_map;
    json::ValuePtr tok_config;
    json::ValuePtr tok;
    Safetensors tensors;
};

std::unique_ptr<InferenceConfig> LoadConfig(
    const std::filesystem::path& directory);

std::filesystem::path FindDefaultModelDir();

} // namespace inference
} // namespace gabby

#endif // GABBY_INFERENCE_CONFIG_H_
14 changes: 3 additions & 11 deletions src/inference/generator.cc
@@ -38,17 +38,9 @@ Message Llama3Generator::Generate(const Request& req) {
}

/* static */
std::unique_ptr<Generator> Llama3Generator::LoadFromDirectory(
    std::filesystem::path dir) {
    auto config = json::ParseFile(dir / "config.json");
    auto gen_config = json::ParseFile(dir / "generation_config.json");
    auto special_tokens_map = json::ParseFile(dir / "special_tokens_map.json");
    auto tok_config = json::ParseFile(dir / "tokenizer_config.json");
    auto tok = json::ParseFile(dir / "tokenizer.json");
    auto tensors = Safetensors::LoadFile(dir / "model.safetensors");
    return std::unique_ptr<Generator>(
        new Llama3Generator(config, gen_config, special_tokens_map, tok_config,
                            tok, std::move(tensors)));
std::unique_ptr<Generator> Llama3Generator::Load(
    std::unique_ptr<InferenceConfig> config) {
    return std::unique_ptr<Generator>(new Llama3Generator(std::move(config)));
}

} // namespace inference
24 changes: 6 additions & 18 deletions src/inference/generator.h
@@ -6,6 +6,7 @@
#include <ostream>
#include <string>

#include "inference/config.h"
#include "inference/safetensors.h"
#include "json/json.h"

@@ -36,26 +37,13 @@ class Llama3Generator : public Generator {
   public:
    Message Generate(const Request& req) override;

    static std::unique_ptr<Generator> LoadFromDirectory(
        std::filesystem::path dir);
    static std::unique_ptr<Generator> Load(
        std::unique_ptr<InferenceConfig> config);

   private:
    Llama3Generator(json::ValuePtr config, json::ValuePtr gen_config,
                    json::ValuePtr special_tokens_map,
                    json::ValuePtr tok_config, json::ValuePtr tok,
                    Safetensors tensors)
        : config_(config),
          gen_config_(gen_config),
          special_tokens_map_(special_tokens_map),
          tok_config_(tok_config),
          tok_(tok),
          tensors_(std::move(tensors)) {}
    json::ValuePtr config_;
    json::ValuePtr gen_config_;
    json::ValuePtr special_tokens_map_;
    json::ValuePtr tok_config_;
    json::ValuePtr tok_;
    Safetensors tensors_;
    Llama3Generator(std::unique_ptr<InferenceConfig> config)
        : config_(std::move(config)) {}
    std::unique_ptr<InferenceConfig> config_;
};

} // namespace inference
14 changes: 14 additions & 0 deletions src/inference/tokenizer.cc
@@ -0,0 +1,14 @@
#include "inference/tokenizer.h"

namespace gabby {
namespace inference {

// Boilerplate stub: real tokenization is not implemented yet, so every input
// produces an empty token sequence.
std::vector<int> Tokenizer::Tokenize(const std::string_view input) {
    return {};
}

// The parsed tokenizer JSON is accepted but not yet used.
Tokenizer::Tokenizer(json::ValuePtr special_tokens_map,
                     json::ValuePtr tokenizer_config, json::ValuePtr tokens) {}

} // namespace inference
} // namespace gabby
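
Tokenize is deliberately left as a stub in this commit. As a rough illustration of the shape it could take, below is a greedy longest-prefix tokenizer over a flat vocabulary. This is a sketch only, not the commit's implementation: it uses standard containers instead of the repo's json::ValuePtr API and assumes the vocabulary has already been extracted from tokenizer.json into a plain map; the actual Llama 3 tokenizer is byte-level BPE with merge rules and special tokens, which this ignores.

#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

// Sketch: greedy longest-prefix matching against a flat vocab.
// Assumes `vocab` was built elsewhere from tokenizer.json; unmatched bytes map to -1.
std::vector<int> GreedyTokenize(
        const std::unordered_map<std::string, int>& vocab,
        std::string_view input) {
    std::vector<int> ids;
    size_t pos = 0;
    while (pos < input.size()) {
        size_t best_len = 0;
        int best_id = -1;
        // Try the longest candidate first so e.g. "the" wins over "t" + "h" + "e".
        for (size_t len = input.size() - pos; len > 0; --len) {
            auto it = vocab.find(std::string(input.substr(pos, len)));
            if (it != vocab.end()) {
                best_len = len;
                best_id = it->second;
                break;
            }
        }
        if (best_len == 0) {  // no match: emit a placeholder id and skip one byte
            ids.push_back(-1);
            best_len = 1;
        } else {
            ids.push_back(best_id);
        }
        pos += best_len;
    }
    return ids;
}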
26 changes: 26 additions & 0 deletions src/inference/tokenizer.h
@@ -0,0 +1,26 @@
#ifndef GABBY_INFERENCE_TOKENIZER_H_
#define GABBY_INFERENCE_TOKENIZER_H_

#include <memory>
#include <string_view>
#include <vector>

#include "json/json.h"

namespace gabby {
namespace inference {

class Tokenizer {
   public:
    virtual ~Tokenizer() = default;

    Tokenizer(json::ValuePtr special_tokens_map,
              json::ValuePtr tokenizer_config, json::ValuePtr tokens);

    virtual std::vector<int> Tokenize(const std::string_view input);

   private:
};

} // namespace inference
} // namespace gabby

#endif // GABBY_INFERENCE_TOKENIZER_H_
28 changes: 28 additions & 0 deletions src/inference/tokenizer_test.cc
@@ -0,0 +1,28 @@
#include "inference/tokenizer.h"

#include "test/env.h"
#include "test/test.h"

namespace gabby {
namespace inference {

TEST(Tokenizer, Empty) {
    const auto& config = GlobalConfig();
    Tokenizer tok(config.special_tokens_map, config.tok_config, config.tok);
    EXPECT_EQ(std::vector<int>{}, tok.Tokenize(""));
}

// Placeholder: to be filled in with real expectations once Tokenize is implemented.
TEST(Tokenizer, Tokenize) {
    const auto& config = GlobalConfig();
    Tokenizer tok(config.special_tokens_map, config.tok_config, config.tok);
    EXPECT_EQ(std::vector<int>{}, tok.Tokenize(""));
}

// Placeholder: to be filled in with special-token expectations.
TEST(Tokenizer, SpecialTokens) {
    const auto& config = GlobalConfig();
    Tokenizer tok(config.special_tokens_map, config.tok_config, config.tok);
    EXPECT_EQ(std::vector<int>{}, tok.Tokenize(""));
}

} // namespace inference
} // namespace gabby
31 changes: 2 additions & 29 deletions src/main.cc
@@ -7,6 +7,7 @@
#include <string_view>
#include <thread>

#include "inference/config.h"
#include "service.h"
#include "utils/logging.h"

@@ -27,34 +28,6 @@ std::ostream& operator<<(std::ostream& os, const Config& config) {
<< " }";
}

constexpr std::string_view kUserRelativeSnapshotDir =
    ".cache/huggingface/hub/models--meta-llama--Llama-3.2-1B-Instruct/"
    "snapshots";

fs::path FindModelDir() {
    const char* home = getenv("HOME");
    if (home == nullptr) {
        Die("env var HOME is unset");
    }

    fs::path snapshots_dir(home);
    snapshots_dir /= kUserRelativeSnapshotDir;
    fs::directory_iterator contents;
    try {
        contents = fs::directory_iterator(snapshots_dir);
    } catch (fs::filesystem_error& err) {
        Die(std::format("can't access model dir at {}: {}",
                        snapshots_dir.string(), err.what()));
    }

    auto it = fs::begin(contents);
    if (it == fs::end(contents)) {
        Die(std::format("no snapshots found in {}", snapshots_dir.string()));
    }

    return *it;
}

Config DefaultConfig() {
    return Config{
        .log_level = LogLevel::OFF,
@@ -113,7 +86,7 @@ Config ParseConfig(int argc, char* argv[]) {
        }
    }
    if (config.model_dir.empty()) {
        config.model_dir = FindModelDir();
        config.model_dir = inference::FindDefaultModelDir();
    }
    return config;
}
5 changes: 3 additions & 2 deletions src/service.cc
@@ -6,6 +6,7 @@
#include <unordered_set>

#include "http/router.h"
#include "inference/config.h"
#include "json/json.h"
#include "json/parser.h"
#include "utils/logging.h"
@@ -119,8 +120,8 @@ json::ValuePtr MakeResponse(const inference::Message& answer) {
InferenceService::InferenceService(Config config)
    : config_(config),
      server_(std::make_unique<http::HttpServer>(config_.server_config)),
      generator_(
          inference::Llama3Generator::LoadFromDirectory(config.model_dir)) {}
      generator_(inference::Llama3Generator::Load(
          inference::LoadConfig(config.model_dir))) {}

InferenceService::InferenceService(
    std::unique_ptr<http::HttpServer> server,
12 changes: 12 additions & 0 deletions src/test/env.h
@@ -0,0 +1,12 @@
#ifndef GABBY_TEST_ENV_H_
#define GABBY_TEST_ENV_H_

#include "inference/config.h"

namespace gabby {

extern const inference::InferenceConfig& GlobalConfig();

} // namespace gabby

#endif // GABBY_TEST_ENV_H_
19 changes: 19 additions & 0 deletions src/test/test_main.cc
@@ -1,5 +1,24 @@
#include <mutex>

#include "inference/config.h"
#include "test/env.h"
#include "test/test.h"

namespace gabby {

static inference::InferenceConfig* kGlobalConfig = nullptr;
std::once_flag once;

const inference::InferenceConfig& GlobalConfig() {
    std::call_once(once, [] {
        kGlobalConfig =
            inference::LoadConfig(inference::FindDefaultModelDir()).release();
    });
    return *kGlobalConfig;
}

} // namespace gabby
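
Any other test target that needs the model artifacts can share this lazily initialized, intentionally leaked config. A hypothetical example (mirroring tokenizer_test.cc, not part of this commit) showing that repeated calls return the same instance and so the safetensors are loaded at most once per test binary:

#include "test/env.h"
#include "test/test.h"

namespace gabby {

// Hypothetical test: the shared config is created once and reused.
TEST(Env, GlobalConfigIsSingleton) {
    const auto& first = GlobalConfig();
    const auto& second = GlobalConfig();
    EXPECT_EQ(&first, &second);
}

}  // namespace gabby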

int main(int argc, char* argv[]) {
    int failures = 0;
    std::cout << "running " << gabby::kTestCases->size() << " tests\n";
