-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e271597
commit 648714b
Showing 11 changed files with 200 additions and 60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
#include "inference/config.h" | ||
|
||
#include "json/parser.h" | ||
#include "utils/logging.h" | ||
|
||
namespace gabby { | ||
namespace inference { | ||
|
||
namespace fs = std::filesystem; | ||
|
||
std::unique_ptr<InferenceConfig> LoadConfig(const std::filesystem::path& dir) { | ||
LOG(DEBUG) << "loading model from: " << dir; | ||
auto config = json::ParseFile(dir / "config.json"); | ||
auto gen_config = json::ParseFile(dir / "generation_config.json"); | ||
auto special_tokens_map = json::ParseFile(dir / "special_tokens_map.json"); | ||
auto tok_config = json::ParseFile(dir / "tokenizer_config.json"); | ||
auto tok = json::ParseFile(dir / "tokenizer.json"); | ||
auto tensors = Safetensors::LoadFile(dir / "model.safetensors"); | ||
LOG(DEBUG) << "successfully loaded model"; | ||
return std::unique_ptr<InferenceConfig>(new InferenceConfig{ | ||
.config = config, | ||
.gen_config = gen_config, | ||
.special_tokens_map = special_tokens_map, | ||
.tok_config = tok_config, | ||
.tok = tok, | ||
.tensors = std::move(tensors), | ||
}); | ||
} | ||
|
||
// Path of the Llama-3.2-1B-Instruct snapshots directory inside the Hugging
// Face hub cache, relative to the user's home directory.
constexpr std::string_view kUserRelativeSnapshotDir{
    ".cache/huggingface/hub/"
    "models--meta-llama--Llama-3.2-1B-Instruct/snapshots"};
|
||
// Locates a default model snapshot below the current user's home directory.
//
// Resolves $HOME / kUserRelativeSnapshotDir and returns the path of the first
// sub-directory encountered while listing it. Non-directory entries (stray
// files) are skipped so that a file is never reported as a model directory.
// The listing order is whatever the directory iterator yields.
//
// Throws std::runtime_error when HOME is unset, when the snapshots directory
// cannot be listed, or when it contains no snapshot directory.
fs::path FindDefaultModelDir() {
    const char* home = getenv("HOME");
    if (home == nullptr) throw std::runtime_error("env var HOME is unset");

    fs::path snapshots_dir(home);
    snapshots_dir /= kUserRelativeSnapshotDir;

    try {
        for (const fs::directory_entry& entry :
             fs::directory_iterator(snapshots_dir)) {
            if (entry.is_directory()) return entry.path();
        }
    } catch (const fs::filesystem_error& err) {
        // Covers both opening the directory and advancing through it.
        throw std::runtime_error(std::format("can't access model dir at {}: {}",
                                             snapshots_dir.string(),
                                             err.what()));
    }

    throw std::runtime_error(
        std::format("no snapshots found in {}", snapshots_dir.string()));
}
} // namespace inference | ||
} // namespace gabby |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
#ifndef GABBY_INFERENCE_CONFIG_H_ | ||
#define GABBY_INFERENCE_CONFIG_H_ | ||
|
||
#include <filesystem> | ||
|
||
#include "inference/safetensors.h" | ||
#include "json/json.h" | ||
|
||
namespace gabby { | ||
namespace inference { | ||
|
||
struct InferenceConfig { | ||
json::ValuePtr config; | ||
json::ValuePtr gen_config; | ||
json::ValuePtr special_tokens_map; | ||
json::ValuePtr tok_config; | ||
json::ValuePtr tok; | ||
Safetensors tensors; | ||
}; | ||
|
||
// Loads the JSON configuration files and the safetensors weights found in
// `directory` and returns them bundled in a newly allocated InferenceConfig.
std::unique_ptr<InferenceConfig> LoadConfig(
    const std::filesystem::path& directory);

// Returns the path of a snapshot located under the Llama-3.2-1B-Instruct
// entry of the Hugging Face hub cache in the user's home directory.
// Throws std::runtime_error when HOME is unset, when the snapshots directory
// cannot be accessed, or when no snapshot is found in it.
std::filesystem::path FindDefaultModelDir();
|
||
} // namespace inference | ||
} // namespace gabby | ||
|
||
#endif // GABBY_INFERENCE_CONFIG_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#include "inference/tokenizer.h" | ||
|
||
namespace gabby { | ||
namespace inference { | ||
|
||
// Placeholder implementation: `input` is not examined and the returned token
// sequence is always empty.
std::vector<int> Tokenizer::Tokenize(const std::string_view input) {
    std::vector<int> tokens;
    return tokens;
}
|
||
// Constructs a tokenizer. The three JSON handles are accepted but neither
// stored nor inspected; the body is currently empty.
Tokenizer::Tokenizer(json::ValuePtr special_tokens_map,
                     json::ValuePtr tokenizer_config,
                     json::ValuePtr tokens) {
}
|
||
} // namespace inference | ||
} // namespace gabby |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#ifndef GABBY_INFERENCE_TOKENIZER_H_ | ||
#define GABBY_INFERENCE_TOKENIZER_H_ | ||
|
||
#include <memory> | ||
|
||
#include "json/json.h" | ||
|
||
namespace gabby { | ||
namespace inference { | ||
|
||
class Tokenizer { | ||
public: | ||
virtual ~Tokenizer() = default; | ||
|
||
Tokenizer(json::ValuePtr special_tokens_map, | ||
json::ValuePtr tokenizer_config, json::ValuePtr tokens); | ||
|
||
virtual std::vector<int> Tokenize(const std::string_view input); | ||
|
||
private: | ||
}; | ||
|
||
} // namespace inference | ||
} // namespace gabby | ||
|
||
#endif // GABBY_INFERENCE_TOKENIZER_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#include "inference/tokenizer.h" | ||
|
||
#include "test/env.h" | ||
#include "test/test.h" | ||
|
||
namespace gabby { | ||
namespace inference { | ||
|
||
// Tokenizing the empty string is expected to produce no tokens.
TEST(Tokenizer, Empty) {
    const auto& cfg = GlobalConfig();
    Tokenizer tok{cfg.special_tokens_map, cfg.tok_config, cfg.tok};
    EXPECT_EQ(std::vector<int>{}, tok.Tokenize(""));
}
|
||
// NOTE(review): despite its name this test currently only tokenizes the empty
// string, exactly like the Empty test; it contains no non-empty input yet.
TEST(Tokenizer, Tokenize) {
    const auto& cfg = GlobalConfig();
    Tokenizer tok{cfg.special_tokens_map, cfg.tok_config, cfg.tok};
    EXPECT_EQ(std::vector<int>{}, tok.Tokenize(""));
}
|
||
// NOTE(review): despite its name this test currently only tokenizes the empty
// string, exactly like the Empty test; no special token appears in the input.
TEST(Tokenizer, SpecialTokens) {
    const auto& cfg = GlobalConfig();
    Tokenizer tok{cfg.special_tokens_map, cfg.tok_config, cfg.tok};
    EXPECT_EQ(std::vector<int>{}, tok.Tokenize(""));
}
|
||
} // namespace inference | ||
} // namespace gabby |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#ifndef GABBY_TEST_ENV_H_ | ||
#define GABBY_TEST_ENV_H_ | ||
|
||
#include "inference/config.h" | ||
|
||
namespace gabby { | ||
|
||
// Accessor for the InferenceConfig shared by the tests; defined in another
// translation unit. Returns a const reference, so callers do not own it.
const inference::InferenceConfig& GlobalConfig();
|
||
} // namespace gabby | ||
|
||
#endif // GABBY_TEST_ENV_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters