diff --git a/cpp/README.md b/cpp/README.md index 36160464fd..3a8454b501 100644 --- a/cpp/README.md +++ b/cpp/README.md @@ -12,7 +12,7 @@ python ts_scripts/install_dependencies.py --cpp [--cuda=cu121|cu118] ### Building the backend ``` ## Dev Build -cd serve/cpp +cd serve/cpp ./build.sh [-g cu121|cu118] ## Install TorchServe from source @@ -34,32 +34,60 @@ cd serve torchserve torchserve --ncs --start --model-store model_store ``` ## Backend -TorchServe cpp backend can run as a process, which is similar to [TorchServe Python backend](https://github.com/pytorch/serve/tree/master/ts). By default, TorchServe supports torch scripted model in cpp backend. [src/backends/core/backend.hh](https://github.com/pytorch/serve/blob/cpp_backend/cpp/src/backends/core/backend.hh) defines the APIs of backend to support multiple different platforms such as MxNet, ONNX and so on. -* [Backend](https://github.com/pytorch/serve/blob/cpp_backend/cpp/src/backends/core/backend.hh#L60) defines function `LoadModelInternal` to support model loading on different platforms. -* [ModelInstance](https://github.com/pytorch/serve/blob/cpp_backend/cpp/src/backends/core/backend.hh#L25) represents a model copy. The function `Predict` is to support prediction on different platforms. -### TorchScripted Backend -By default, TorchServe cpp provides [TorchScripted backend](https://github.com/pytorch/serve/tree/cpp_backend/cpp/src/backends/torch_scripted). Its [base handler](https://github.com/pytorch/serve/blob/cpp_backend/cpp/src/backends/torch_scripted/handler/base_handler.hh) defines APIs to customize handler. -* [Initialize](https://github.com/pytorch/serve/blob/cpp_backend/cpp/src/backends/torch_scripted/handler/base_handler.hh#L29) -* [LoadModel](https://github.com/pytorch/serve/blob/cpp_backend/cpp/src/backends/torch_scripted/handler/base_handler.hh#L37) -* [Preprocess](https://github.com/pytorch/serve/blob/cpp_backend/cpp/src/backends/torch_scripted/handler/base_handler.hh#L40) -* [Inference](https://github.com/pytorch/serve/blob/cpp_backend/cpp/src/backends/torch_scripted/handler/base_handler.hh#L46) -* [Postprocess](https://github.com/pytorch/serve/blob/cpp_backend/cpp/src/backends/torch_scripted/handler/base_handler.hh#L53) +The TorchServe cpp backend runs as a process, similar to the [TorchServe Python backend](https://github.com/pytorch/serve/tree/master/ts). By default, the cpp backend supports TorchScript models. Other platforms such as MXNet or ONNX can be supported through custom handlers, following the TorchScript example [src/backends/handler/torch_scripted_handler.hh](https://github.com/pytorch/serve/blob/master/cpp/src/backends/handler/torch_scripted_handler.hh). +### Custom Handler +By default, TorchServe cpp provides a TorchScript handler, [src/backends/handler/torch_scripted_handler.hh](https://github.com/pytorch/serve/blob/master/cpp/src/backends/handler/torch_scripted_handler.hh). It uses the [BaseHandler](https://github.com/pytorch/serve/blob/master/cpp/src/backends/handler/base_handler.hh), which defines the APIs for customizing a handler.
+* [Initialize](https://github.com/pytorch/serve/blob/cpp_backend/cpp/src/backends/handler/base_handler.hh#L29) +* [LoadModel](https://github.com/pytorch/serve/blob/cpp_backend/cpp/src/backends/handler/base_handler.hh#L37) +* [Preprocess](https://github.com/pytorch/serve/blob/cpp_backend/cpp/src/backends/handler/base_handler.hh#L40) +* [Inference](https://github.com/pytorch/serve/blob/cpp_backend/cpp/src/backends/handler/base_handler.hh#L46) +* [Postprocess](https://github.com/pytorch/serve/blob/cpp_backend/cpp/src/backends/handler/base_handler.hh#L53) #### Example -##### Using BaseHandler -* set runtime as "LSP" in model archiver option [--runtime](https://github.com/pytorch/serve/tree/master/model-archiver#arguments) -* set handler as "BaseHandler" in model archiver option [--handler](https://github.com/pytorch/serve/tree/master/model-archiver#arguments) +##### Using TorchScriptHandler +* set runtime as "LSP" in model archiver option [--runtime](https://github.com/pytorch/serve/tree/master/model-archiver#arguments) +* set handler as "TorchScriptHandler" in model archiver option [--handler](https://github.com/pytorch/serve/tree/master/model-archiver#arguments) ``` - torch-model-archiver --model-name mnist_base --version 1.0 --serialized-file mnist_script.pt --handler BaseHandler --runtime LSP + torch-model-archiver --model-name mnist_base --version 1.0 --serialized-file mnist_script.pt --handler TorchScriptHandler --runtime LSP ``` Here is an [example](https://github.com/pytorch/serve/tree/cpp_backend/cpp/test/resources/torchscript_model/mnist/base_handler) of unzipped model mar file. -##### Using customized handler +##### Using Custom Handler * build customized handler shared lib. For example [Mnist handler](https://github.com/pytorch/serve/blob/cpp_backend/cpp/src/examples/image_classifier/mnist). -* set runtime as "LSP" in model archiver option [--runtime](https://github.com/pytorch/serve/tree/master/model-archiver#arguments) +* set runtime as "LSP" in model archiver option [--runtime](https://github.com/pytorch/serve/tree/master/model-archiver#arguments) * set handler as "libmnist_handler:MnistHandler" in model archiver option [--handler](https://github.com/pytorch/serve/tree/master/model-archiver#arguments) ``` torch-model-archiver --model-name mnist_handler --version 1.0 --serialized-file mnist_script.pt --handler libmnist_handler:MnistHandler --runtime LSP ``` Here is an [example](https://github.com/pytorch/serve/tree/cpp_backend/cpp/test/resources/torchscript_model/mnist/mnist_handler) of unzipped model mar file.
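(Editor's illustration, not part of this patch.) In outline, a custom handler such as the Mnist one above is a class deriving from `torchserve::BaseHandler`, built into a shared library that exports C `allocator<ClassName>`/`deleter<ClassName>` entry points so that `Backend::LoadHandler` (later in this diff) can instantiate it at runtime. A minimal sketch follows; the names `MyHandler` and `libmy_handler` are hypothetical, and the override signature is an assumption based on the declarations in `base_handler.hh`:

```cpp
// my_handler.cc -- hypothetical minimal custom handler (illustrative only)
#include <fmt/format.h>
#include <torch/script.h>

#include "src/backends/handler/base_handler.hh"

namespace my_example {
class MyHandler : public torchserve::BaseHandler {
 public:
  // Only LoadModel is overridden here; Preprocess/Inference/Postprocess are
  // inherited from BaseHandler and can be overridden in the same way.
  std::pair<std::shared_ptr<void>, std::shared_ptr<torch::Device>> LoadModel(
      std::shared_ptr<torchserve::LoadModelRequest>& load_model_request)
      override {
    auto device = GetTorchDevice(load_model_request);
    // Load the TorchScript module named in the MAR manifest onto the device.
    std::shared_ptr<void> model = std::make_shared<torch::jit::Module>(
        torch::jit::load(fmt::format("{}/{}", load_model_request->model_dir,
                                     manifest_->GetModel().serialized_file),
                         *device));
    return std::make_pair(model, device);
  }
};
}  // namespace my_example

// Exported entry points matching the "libmy_handler:MyHandler" handler name:
// the backend dlopens libmy_handler.so and looks up allocatorMyHandler and
// deleterMyHandler by this naming convention.
extern "C" {
torchserve::BaseHandler* allocatorMyHandler() {
  return new my_example::MyHandler();
}
void deleterMyHandler(torchserve::BaseHandler* p) {
  if (p != nullptr) {
    delete static_cast<my_example::MyHandler*>(p);
  }
}
}
```

The corresponding archiver call would then follow the same pattern as above, e.g. `torch-model-archiver ... --handler libmy_handler:MyHandler --runtime LSP`.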
+##### BabyLlama Example +The babyllama example can be found [here](https://github.com/pytorch/serve/blob/master/cpp/src/examples/babyllama/). +To run the example, we need to download the model weights as well as the tokenizer file: +```bash +wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin +wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin +``` +Subsequently, we need to adjust the paths in [config.json](https://github.com/pytorch/serve/blob/master/cpp/test/resources/torchscript_model/babyllama/babyllama_handler/config.json) to match our local file structure: +```json +{ +"checkpoint_path" : "/home/ubuntu/serve/cpp/stories15M.bin", +"tokenizer_path" : "/home/ubuntu/serve/cpp/src/examples/babyllama/tokenizer.bin" +} +``` +Then we can create the mar file and deploy it with: +```bash +cd serve/cpp/test/resources/torchscript_model/babyllama/babyllama_handler +torch-model-archiver --model-name llm --version 1.0 --handler libbabyllama_handler:BabyLlamaHandler --runtime LSP --extra-files config.json +mkdir model_store && mv llm.mar model_store/ +torchserve --ncs --start --model-store model_store + +curl -v -X POST "http://localhost:8081/models?initial_workers=1&url=llm.mar" +``` +The handler name `libbabyllama_handler:BabyLlamaHandler` consists of the shared library name (as defined in our [CMakeLists.txt](https://github.com/pytorch/serve/blob/master/cpp/src/examples/CMakeLists.txt)) and the class name we chose for our [custom handler class](https://github.com/pytorch/serve/blob/master/cpp/src/examples/babyllama/baby_llama_handler.cc), which derives from BaseHandler. + +To test the model, we can run: +```bash +cd serve/cpp/test/resources/torchscript_model/babyllama/ +curl http://localhost:8080/predictions/llm -T prompt.txt +``` ##### Mnist example * Transform data on client side. For example: ``` @@ -75,9 +103,4 @@ image = Image.open("examples/image_classifier/mnist/test_data/0.png") image = image_processing(image) torch.save(image, "0_png.pt") ``` -* Run model registration and prediction: [Using BaseHandler](https://github.com/pytorch/serve/blob/cpp_backend/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc#L54) or [Using customized handler](https://github.com/pytorch/serve/blob/cpp_backend/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc#L72). - - - - - +* Run model registration and prediction: [Using BaseHandler](https://github.com/pytorch/serve/blob/cpp_backend/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc#L54) or [Using customized handler](https://github.com/pytorch/serve/blob/cpp_backend/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc#L72). diff --git a/cpp/build.sh b/cpp/build.sh index ad883abc6a..ca0eecf765 100755 --- a/cpp/build.sh +++ b/cpp/build.sh @@ -212,6 +212,10 @@ function build() { mv $DEPS_DIR/../src/examples/libmnist_handler.so $DEPS_DIR/../../test/resources/torchscript_model/mnist/mnist_handler/libmnist_handler.so fi + if [ -f "$DEPS_DIR/../src/examples/libbabyllama_handler.so" ]; then + mv $DEPS_DIR/../src/examples/libbabyllama_handler.so $DEPS_DIR/../../test/resources/torchscript_model/babyllama/babyllama_handler/libbabyllama_handler.so + fi + cd $DEPS_DIR/../..
if [ -f "$DEPS_DIR/../test/torchserve_cpp_test" ]; then $DEPS_DIR/../test/torchserve_cpp_test diff --git a/cpp/src/backends/CMakeLists.txt b/cpp/src/backends/CMakeLists.txt index 40a30be339..9824d41f62 100644 --- a/cpp/src/backends/CMakeLists.txt +++ b/cpp/src/backends/CMakeLists.txt @@ -15,40 +15,27 @@ target_link_libraries(ts_backends_protocol PRIVATE ts_utils ${FOLLY_LIBRARIES}) install(TARGETS ts_backends_protocol DESTINATION ${torchserve_cpp_SOURCE_DIR}/_build/libs) # build library ts_backend_core -set(TS_BACKENDS_CORE_SOURCE_FILES "") -list(APPEND TS_BACKENDS_CORE_SOURCE_FILES ${TS_BACKENDS_CORE_SRC_DIR}/backend.cc) -add_library(ts_backends_core SHARED ${TS_BACKENDS_CORE_SOURCE_FILES}) +set(BACKEND_SOURCE_FILES "") +list(APPEND BACKEND_SOURCE_FILES ${TS_BACKENDS_SRC_DIR}/core/backend.cc) +list(APPEND BACKEND_SOURCE_FILES ${TS_BACKENDS_SRC_DIR}/core/model_instance.cc) +list(APPEND BACKEND_SOURCE_FILES ${TS_BACKENDS_SRC_DIR}/handler/base_handler.cc) +list(APPEND BACKEND_SOURCE_FILES ${TS_BACKENDS_SRC_DIR}/handler/torch_scripted_handler.cc) +add_library(ts_backends_core SHARED ${BACKEND_SOURCE_FILES}) target_include_directories(ts_backends_core PUBLIC ${TS_BACKENDS_CORE_SRC_DIR}) -target_link_libraries(ts_backends_core PRIVATE ts_utils ts_backends_protocol ${FOLLY_LIBRARIES}) +target_link_libraries(ts_backends_core PUBLIC ts_utils ts_backends_protocol ${FOLLY_LIBRARIES}) install(TARGETS ts_backends_core DESTINATION ${torchserve_cpp_SOURCE_DIR}/_build/libs) -# build library ts_backend_torch_scripted -set(TS_BACKENDS_TORCH_SCRIPTED_SOURCE_FILES "") -list(APPEND TS_BACKENDS_TORCH_SCRIPTED_SOURCE_FILES ${TS_BACKENDS_TORCH_SCRIPTED_SRC_DIR}/torch_scripted_backend.cc) -list(APPEND TS_BACKENDS_TORCH_SCRIPTED_SOURCE_FILES ${TS_BACKENDS_TORCH_SCRIPTED_SRC_DIR}/handler/base_handler.cc) -add_library(ts_backends_torch_scripted SHARED ${TS_BACKENDS_TORCH_SCRIPTED_SOURCE_FILES}) -target_include_directories(ts_backends_torch_scripted PUBLIC - ${TS_BACKENDS_TORCH_SCRIPTED_SRC_DIR} ${TS_BACKENDS_TORCH_SCRIPTED_SRC_DIR}/handler ${TORCH_INCLUDE_DIRS}) -target_link_libraries(ts_backends_torch_scripted PUBLIC ts_utils ts_backends_core ${TORCH_LIBRARIES}) -install(TARGETS ts_backends_torch_scripted DESTINATION ${torchserve_cpp_SOURCE_DIR}/_build/libs) - -# build library ts_backend_torch_deploy -#set(TS_BACKENDS_TORCH_DEPLOY_SOURCE_FILES "") -#add_library(ts_backends_torch_deploy SHARED ${TS_BACKENDS_TORCH_DEPLOY_SOURCE_FILES}) -#target_include_directories(ts_backends_torch_deploy PUBLIC ${TS_BACKENDS_TORCH_DEPLOY_SRC_DIR}) -#target_link_libraries(ts_backends_torch_deploy PRIVATE ts_utils ts_backends_core ${TORCH_LIBRARIES}) - # build exe model_worker_socket -add_executable(model_worker_socket +add_executable(model_worker_socket "${TS_BACKENDS_PROCESS_SRC_DIR}/model_worker_socket.cc" "${TS_BACKENDS_PROCESS_SRC_DIR}/model_worker.cc" ) -target_include_directories(model_worker_socket PRIVATE +target_include_directories(model_worker_socket PRIVATE ${TS_BACKENDS_CORE_SRC_DIR} - ${TS_BACKENDS_PROTOCOL_SRC_DIR} - ${TS_BACKENDS_PROCESS_SRC_DIR} - ${TS_BACKENDS_TORCH_SCRIPTED_SRC_DIR} + ${TS_BACKENDS_PROTOCOL_SRC_DIR} + ${TS_BACKENDS_PROCESS_SRC_DIR} + ${TS_BACKENDS_TORCH_SCRIPTED_SRC_DIR} ) -target_link_libraries(model_worker_socket - PRIVATE ts_backends_core ts_backends_protocol ts_backends_torch_scripted ${FOLLY_LIBRARIES}) +target_link_libraries(model_worker_socket + PRIVATE ts_backends_core ts_backends_protocol ${FOLLY_LIBRARIES} ${TORCH_LIBRARIES}) install(TARGETS model_worker_socket DESTINATION 
${torchserve_cpp_SOURCE_DIR}/_build/bin) diff --git a/cpp/src/backends/core/backend.cc b/cpp/src/backends/core/backend.cc index 638f79130e..58b3bae05e 100644 --- a/cpp/src/backends/core/backend.cc +++ b/cpp/src/backends/core/backend.cc @@ -1,6 +1,63 @@ #include "src/backends/core/backend.hh" +#include + +#include "src/backends/handler/handler_factory.hh" + namespace torchserve { +Backend::Backend() {} + +Backend::~Backend() { + handler_.reset(); + model_instance_table_.clear(); + // Todo: do proper cleanup + // dl_loader_->CloseDL(); +} + +bool Backend::Initialize(const std::string &model_dir) { + random_generator_.seed(time(0)); + manifest_ = std::make_shared(); + // TODO: windows + if (!manifest_->Initialize( + fmt::format("{}/MAR-INF/MANIFEST.json", model_dir))) { + return false; + } + + LoadHandler(model_dir); + + if (!handler_) { + return false; + } + + handler_->Initialize(model_dir, manifest_); + + return true; +} + +void Backend::LoadHandler(const std::string &model_dir) { + const std::string &handler_str = manifest_->GetModel().handler; + std::size_t delimiter_pos = handler_str.find(manifest_->kHandler_Delimiter); + if (delimiter_pos != std::string::npos) { +#ifdef __APPLE__ + std::string lib_path = fmt::format("{}/{}.dylib", model_dir, + handler_str.substr(0, delimiter_pos)); +#else + std::string lib_path = fmt::format("{}/{}.so", model_dir, + handler_str.substr(0, delimiter_pos)); +#endif + std::string handler_class_name = handler_str.substr(delimiter_pos + 1); + std::string allocator_func = fmt::format("allocator{}", handler_class_name); + std::string deleter_func = fmt::format("deleter{}", handler_class_name); + + dl_loader_ = std::make_unique>( + lib_path, allocator_func, deleter_func); + dl_loader_->OpenDL(); + handler_ = dl_loader_->GetInstance(); + } else { + handler_ = HandlerFactory::GetInstance().createHandler(handler_str); + } +} + std::unique_ptr Backend::LoadModel( std::shared_ptr load_model_request) { /** @@ -13,12 +70,43 @@ std::unique_ptr Backend::LoadModel( * - status_READY: return the model instance if it is already. 
* * Common steps: - * https://github.com/pytorch/serve/blob/master/ts/model_loader.py#L62 + * serve/blob/master/ts/model_loader.py#L62 */ + // TODO: support request envelope: + // serve/tree/master/ts/torch_handler/request_envelope + return LoadModelInternal(std::move(load_model_request)); } +std::unique_ptr Backend::LoadModelInternal( + std::shared_ptr load_model_request) { + std::string model_instance_id = BuildModelInstanceId(load_model_request); + try { + model_instance_table_[model_instance_id] = { + ModelInstanceStatus::INIT, std::shared_ptr(nullptr)}; + + auto result = handler_->LoadModel(load_model_request); + SetModelInstanceInfo(model_instance_id, ModelInstanceStatus::READY, + std::make_shared( + model_instance_id, std::move(result.first), + handler_, std::move(result.second))); + + ready_model_instance_ids_.emplace_back(model_instance_id); + std::string message = + fmt::format("loaded model {}", load_model_request->model_name); + return std::make_unique( + // TODO: check current response msg content + 200, message); + } catch (const c10::Error &e) { + SetModelInstanceInfo(model_instance_id, ModelInstanceStatus::FAILED, + std::shared_ptr(nullptr)); + return std::make_unique( + // TODO: check existing + 500, e.msg()); + } +} + std::string Backend::BuildModelInstanceId( std::shared_ptr load_model_request) { std::string device_type("cpu"); @@ -30,7 +118,7 @@ std::string Backend::BuildModelInstanceId( } void Backend::SetModelInstanceInfo( - const std::string& model_instance_id, ModelInstanceStatus new_status, + const std::string &model_instance_id, ModelInstanceStatus new_status, std::shared_ptr new_model_instance) { model_instance_table_[model_instance_id].status = new_status; model_instance_table_[model_instance_id].model_instance = @@ -38,7 +126,7 @@ void Backend::SetModelInstanceInfo( } torchserve::Backend::ModelInstanceStatus Backend::GetModelInstanceStatus( - const std::string& model_instance_id) { + const std::string &model_instance_id) { auto model_instance_info = model_instance_table_.find(model_instance_id); if (model_instance_info == model_instance_table_.end()) { return torchserve::Backend::ModelInstanceStatus::NOT_INIT; @@ -47,7 +135,7 @@ torchserve::Backend::ModelInstanceStatus Backend::GetModelInstanceStatus( } std::shared_ptr Backend::GetModelInstance( - const std::string& model_instance_id) { + const std::string &model_instance_id) { auto model_instance_info = model_instance_table_.find(model_instance_id); if (model_instance_info == model_instance_table_.end()) { return std::shared_ptr(nullptr); diff --git a/cpp/src/backends/core/backend.hh b/cpp/src/backends/core/backend.hh index e6e54a1c3f..d7b2a59826 100644 --- a/cpp/src/backends/core/backend.hh +++ b/cpp/src/backends/core/backend.hh @@ -1,5 +1,4 @@ -#ifndef TS_CPP_BACKENDS_CORE_BACKEND_HH_ -#define TS_CPP_BACKENDS_CORE_BACKEND_HH_ +#pragma once #include #include @@ -8,38 +7,19 @@ #include #include #include +#include #include +#include "model_instance.hh" #include "src/utils/config.hh" +#include "src/utils/dl_loader.hh" #include "src/utils/message.hh" namespace torchserve { -/** - * - * @brief TorchServe ModelInstance Interface - * ModelInstance <=> Service: - * https://github.com/pytorch/serve/blob/master/ts/service.py#L21 - */ -class ModelInstance { - public: - ModelInstance(const std::string& instance_id) : instance_id_(instance_id){}; - virtual ~ModelInstance(){}; - - virtual std::shared_ptr Predict( - std::shared_ptr batch) = 0; - - const std::string& GetInstanceId() { return instance_id_; }; - - protected: - 
// instance_id naming convention: - // device_type + ":" + device_id (or object id) - std::string instance_id_; -}; - /** * @brief TorchServe Backend Interface * Backend <=> ModelLoader: - * https://github.com/pytorch/serve/blob/master/ts/model_loader.py#L28 + * serve/blob/master/ts/model_loader.py#L28 * * Note: * Any framework should implement its own backend which includes: @@ -48,7 +28,6 @@ class ModelInstance { * 3. function std::shared_ptr CreateBackend() * * The core idea: - * - A framework has its own backend in a model server. * - A backend has multiple model instances. * - A worker is associated with one model instance. * - A model instance is one model loaded on CPU or GPU. @@ -60,42 +39,38 @@ class Backend { // NOLINTBEGIN(cppcoreguidelines-pro-type-member-init) struct ModelInstanceInfo { ModelInstanceStatus status; - std::shared_ptr model_instance; + std::shared_ptr model_instance; }; // NOLINTEND(cppcoreguidelines-pro-type-member-init) - Backend() = default; - virtual ~Backend() = default; - - virtual bool Initialize(const std::string& model_dir) { - random_generator_.seed(time(0)); - manifest_ = std::make_shared(); - // TODO: windows - return manifest_->Initialize( - fmt::format("{}/MAR-INF/MANIFEST.json", model_dir)); - }; - - std::unique_ptr LoadModel( - std::shared_ptr load_model_request); + Backend(); + virtual ~Backend(); - virtual std::unique_ptr LoadModelInternal( - std::shared_ptr load_model_request) = 0; + bool Initialize(const std::string &model_dir); ModelInstanceStatus GetModelInstanceStatus( - const std::string& model_instance_id); + const std::string &model_instance_id); std::shared_ptr GetModelInstance( - const std::string& model_instance_id); + const std::string &model_instance_id); std::shared_ptr GetModelInstance(); - void SetModelInstanceInfo( - const std::string& model_instance_id, ModelInstanceStatus new_status, - std::shared_ptr new_model_instance); + void SetModelInstanceInfo(const std::string &model_instance_id, + ModelInstanceStatus new_status, + std::shared_ptr new_model_instance); + + std::unique_ptr LoadModel( + std::shared_ptr load_model_request); protected: std::string BuildModelInstanceId( std::shared_ptr load_model_request); + void LoadHandler(const std::string &model_dir); + + std::unique_ptr LoadModelInternal( + std::shared_ptr load_model_request); + std::shared_ptr manifest_; // key: model_instance_id @@ -106,16 +81,10 @@ class Backend { std::atomic_uint16_t model_instance_count_ = 0; - private: - std::size_t Random(); + std::unique_ptr> dl_loader_; + std::shared_ptr handler_; + std::size_t Random(); std::mt19937 random_generator_; }; - -class ModelWorker { - public: - ModelWorker(){}; - ~ModelWorker(); -}; } // namespace torchserve -#endif // TS_CPP_BACKENDS_CORE_BACKEND_HH_ diff --git a/cpp/src/backends/core/model_instance.cc b/cpp/src/backends/core/model_instance.cc new file mode 100644 index 0000000000..d534252fe7 --- /dev/null +++ b/cpp/src/backends/core/model_instance.cc @@ -0,0 +1,24 @@ +#include "model_instance.hh" + +#include + +namespace torchserve { + +ModelInstance::ModelInstance(const std::string& instance_id, + std::shared_ptr model, + std::shared_ptr& handler, + std::shared_ptr device) + : instance_id_(instance_id), + model_(model), + handler_(handler), + device_(device) {} + +std::shared_ptr ModelInstance::Predict( + std::shared_ptr request_batch) { + auto response_batch = std::make_shared(); + handler_->Handle(model_, device_, request_batch, response_batch); + + return response_batch; +} + +} // namespace torchserve diff 
--git a/cpp/src/backends/core/model_instance.hh b/cpp/src/backends/core/model_instance.hh new file mode 100644 index 0000000000..0b8ab9c584 --- /dev/null +++ b/cpp/src/backends/core/model_instance.hh @@ -0,0 +1,29 @@ +#pragma once + +#include +#include + +#include + +#include "src/backends/handler/base_handler.hh" + +namespace torchserve { +class ModelInstance { + public: + ModelInstance(const std::string& instance_id, std::shared_ptr model, + std::shared_ptr& handler, + std::shared_ptr device); + virtual ~ModelInstance() = default; + + std::shared_ptr Predict( + std::shared_ptr request_batch); + + protected: + // instance_id naming convention: + // device_type + ":" + device_id (or object id) + std::string instance_id_; + std::shared_ptr model_; + std::shared_ptr handler_; + std::shared_ptr device_; +}; +} // namespace torchserve diff --git a/cpp/src/backends/torch_scripted/handler/base_handler.cc b/cpp/src/backends/handler/base_handler.cc similarity index 83% rename from cpp/src/backends/torch_scripted/handler/base_handler.cc rename to cpp/src/backends/handler/base_handler.cc index e1eef1f80b..ccd25a0ce0 100644 --- a/cpp/src/backends/torch_scripted/handler/base_handler.cc +++ b/cpp/src/backends/handler/base_handler.cc @@ -1,48 +1,26 @@ -#include "src/backends/torch_scripted/handler/base_handler.hh" +#include "base_handler.hh" namespace torchserve { -namespace torchscripted { -std::pair, - std::shared_ptr> -BaseHandler::LoadModel( - std::shared_ptr& load_model_request) { - try { - auto device = GetTorchDevice(load_model_request); - auto module = std::make_shared(torch::jit::load( - // TODO: windows - fmt::format("{}/{}", load_model_request->model_dir, - manifest_->GetModel().serialized_file), - *device)); - return std::make_pair(module, device); - } catch (const c10::Error& e) { - TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}", - load_model_request->model_name, load_model_request->gpu_id, - e.msg()); - throw e; - } catch (const std::runtime_error& e) { - TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}", - load_model_request->model_name, load_model_request->gpu_id, - e.what()); - throw e; - } -} void BaseHandler::Handle( - std::shared_ptr& model, - std::shared_ptr& device, + std::shared_ptr model, std::shared_ptr& device, std::shared_ptr& request_batch, std::shared_ptr& response_batch) { std::string req_ids = ""; std::map map_idx_to_req_id; std::pair&> idx_to_req_id( req_ids, map_idx_to_req_id); + std::string just_passed = ""; try { auto start_time = std::chrono::system_clock::now(); auto inputs = Preprocess(device, idx_to_req_id, request_batch, response_batch); + just_passed = "Preprocessing"; auto outputs = Inference(model, inputs, device, idx_to_req_id, response_batch); + just_passed = "Inference"; Postprocess(outputs, idx_to_req_id, response_batch); + just_passed = "Postprocessing"; auto stop_time = std::chrono::system_clock::now(); std::chrono::duration duration = stop_time - start_time; try { @@ -70,7 +48,7 @@ void BaseHandler::Handle( TS_LOGF(ERROR, "Failed to record PredictionTime metric. {}", e.what()); } } catch (...) 
{ - TS_LOG(ERROR, "Failed to handle this batch"); + TS_LOG(ERROR, "Failed to handle this batch after: {}", just_passed); } } @@ -89,7 +67,7 @@ std::shared_ptr BaseHandler::GetTorchDevice( load_model_request->gpu_id); } -std::vector BaseHandler::Preprocess( +c10::IValue BaseHandler::Preprocess( std::shared_ptr& device, std::pair&>& idx_to_req_id, std::shared_ptr& request_batch, @@ -99,7 +77,8 @@ std::vector BaseHandler::Preprocess( * Ref: * https://github.com/pytorch/serve/blob/be5ff32dab0d81ceb1c2a9d42550ed5904ae9282/ts/torch_handler/vision_handler.py#L33 */ - std::vector batch_ivalue; + auto batch_ivalue = c10::impl::GenericList(c10::TensorType::get()); + std::vector batch_tensors; uint8_t idx = 0; for (auto& request : *request_batch) { @@ -187,9 +166,8 @@ std::vector BaseHandler::Preprocess( return batch_ivalue; } -torch::Tensor BaseHandler::Inference( - std::shared_ptr model, - std::vector& inputs, +c10::IValue BaseHandler::Inference( + std::shared_ptr model, c10::IValue& inputs, std::shared_ptr& device, std::pair&>& idx_to_req_id, std::shared_ptr& response_batch) { @@ -198,7 +176,11 @@ torch::Tensor BaseHandler::Inference( } try { torch::NoGradGuard no_grad; - return model->forward(inputs).toTensor(); + std::shared_ptr jit_model( + std::static_pointer_cast(model)); + std::vector input_vec(inputs.toList().begin(), + inputs.toList().end()); + return jit_model->forward(input_vec).toTensor(); } catch (const std::runtime_error& e) { TS_LOGF(ERROR, "Failed to predict, error: {}", e.what()); for (auto& kv : idx_to_req_id.second) { @@ -212,15 +194,16 @@ torch::Tensor BaseHandler::Inference( } void BaseHandler::Postprocess( - const torch::Tensor& data, + c10::IValue& inputs, std::pair&>& idx_to_req_id, std::shared_ptr& response_batch) { + auto data = inputs.toTensor(); for (const auto& kv : idx_to_req_id.second) { try { auto response = (*response_batch)[kv.second]; response->SetResponse(200, "data_type", torchserve::PayloadType::kDATA_TYPE_BYTES, - torch::pickle_save(data[kv.first])); + torch::pickle_save(at::IValue(data[kv.first]))); } catch (const std::runtime_error& e) { TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}", kv.second, e.what()); @@ -239,5 +222,4 @@ void BaseHandler::Postprocess( } } } -} // namespace torchscripted } // namespace torchserve diff --git a/cpp/src/backends/torch_scripted/handler/base_handler.hh b/cpp/src/backends/handler/base_handler.hh similarity index 69% rename from cpp/src/backends/torch_scripted/handler/base_handler.hh rename to cpp/src/backends/handler/base_handler.hh index 3b612a65f9..d0676aa3b1 100644 --- a/cpp/src/backends/torch_scripted/handler/base_handler.hh +++ b/cpp/src/backends/handler/base_handler.hh @@ -1,5 +1,4 @@ -#ifndef TS_CPP_BACKENDS_TORCH_SCRIPTED_HANDLER_BASE_HANDLER_HH_ -#define TS_CPP_BACKENDS_TORCH_SCRIPTED_HANDLER_BASE_HANDLER_HH_ +#pragma once #include #include @@ -17,11 +16,10 @@ #include "src/utils/model_archive.hh" namespace torchserve { -namespace torchscripted { /** * @brief * TorchBaseHandler <=> BaseHandler: - * https://github.com/pytorch/serve/blob/master/ts/torch_handler/base_handler.py#L37 + * serve/ts/torch_handler/base_handler.py#L37 * * TorchBaseHandler is not responsible for loading model since it is derived * from TorchScritpedModelInstance. 
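(Editor's illustration, not part of this patch.) The hunk below replaces the old `std::vector<torch::jit::IValue>` / `torch::Tensor` signatures with `c10::IValue`, so a handler can pass arbitrary data between `Preprocess`, `Inference`, and `Postprocess`. A small, self-contained sketch of that batching convention, mirroring the `GenericList` usage elsewhere in this diff:

```cpp
#include <iostream>

#include <torch/torch.h>

// Illustrative only: Preprocess builds a generic list of per-request tensors
// and returns it as a single IValue; Inference/Postprocess unwrap it again.
int main() {
  auto batch = c10::impl::GenericList(c10::TensorType::get());
  batch.emplace_back(torch::ones({2, 3}));
  batch.emplace_back(torch::zeros({2, 3}));

  c10::IValue inputs(batch);             // what Preprocess would hand over
  auto tensors = inputs.toTensorList();  // what Inference/Postprocess consume
  for (size_t i = 0; i < tensors.size(); ++i) {
    torch::Tensor t = tensors.get(i);
    std::cout << t.sizes() << std::endl;
  }
  return 0;
}
```

In this convention a handler can also call `toList()` instead of `toTensorList()` when the batch holds non-tensor elements, as the default `Inference` in this diff does.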
@@ -39,38 +37,35 @@ class BaseHandler { manifest_ = manifest; }; - virtual std::pair, - std::shared_ptr> - LoadModel(std::shared_ptr& load_model_request); + virtual std::pair, std::shared_ptr> + LoadModel(std::shared_ptr& load_model_request) = 0; - virtual std::vector Preprocess( + virtual c10::IValue Preprocess( std::shared_ptr& device, std::pair&>& idx_to_req_id, std::shared_ptr& request_batch, std::shared_ptr& response_batch); - virtual torch::Tensor Inference( - std::shared_ptr model, - std::vector& inputs, + virtual c10::IValue Inference( + std::shared_ptr model, c10::IValue& inputs, std::shared_ptr& device, std::pair&>& idx_to_req_id, std::shared_ptr& response_batch); virtual void Postprocess( - const torch::Tensor& data, + c10::IValue& data, std::pair&>& idx_to_req_id, std::shared_ptr& response_batch); /** * @brief * function Predict <=> entry point function handle - * https://github.com/pytorch/serve/blob/master/ts/torch_handler/base_handler.py#L205 + * /serve/ts/torch_handler/base_handler.py#L205 * @param inference_request * @return std::shared_ptr */ void Handle( - std::shared_ptr& model, - std::shared_ptr& device, + std::shared_ptr model, std::shared_ptr& device, std::shared_ptr& request_batch, std::shared_ptr& response_batch); @@ -81,6 +76,4 @@ class BaseHandler { std::shared_ptr manifest_; std::string model_dir_; }; -} // namespace torchscripted } // namespace torchserve -#endif // TS_CPP_BACKENDS_TORCH_HANDLER_BASE_HANDLER_HH_ diff --git a/cpp/src/backends/torch_scripted/handler/handler_factory.hh b/cpp/src/backends/handler/handler_factory.hh similarity index 59% rename from cpp/src/backends/torch_scripted/handler/handler_factory.hh rename to cpp/src/backends/handler/handler_factory.hh index eb14705ca3..52689cecf1 100644 --- a/cpp/src/backends/torch_scripted/handler/handler_factory.hh +++ b/cpp/src/backends/handler/handler_factory.hh @@ -1,13 +1,12 @@ -#ifndef TS_CPP_BACKENDS_TORCH_SCRIPTED_HANDLER_HANDLER_FACTORY_HH_ -#define TS_CPP_BACKENDS_TORCH_SCRIPTED_HANDLER_HANDLER_FACTORY_HH_ +#pragma once #include #include -#include "src/backends/torch_scripted/handler/base_handler.hh" +#include "src/backends/handler/base_handler.hh" +#include "src/backends/handler/torch_scripted_handler.hh" namespace torchserve { -namespace torchscripted { class HandlerFactory { public: static HandlerFactory GetInstance() { @@ -27,11 +26,9 @@ class HandlerFactory { private: std::map (*)()> handlers_ = { - {"BaseHandler", []() -> std::shared_ptr { - return std::make_shared(); + {"TorchScriptHandler", []() -> std::shared_ptr { + return std::make_shared(); }}}; HandlerFactory(){}; }; -} // namespace torchscripted } // namespace torchserve -#endif // TS_CPP_BACKENDS_TORCH_SCRIPTED_HANDLER_HANDLER_FACTORY_HH_ \ No newline at end of file diff --git a/cpp/src/backends/handler/torch_scripted_handler.cc b/cpp/src/backends/handler/torch_scripted_handler.cc new file mode 100644 index 0000000000..271036ce60 --- /dev/null +++ b/cpp/src/backends/handler/torch_scripted_handler.cc @@ -0,0 +1,33 @@ +#include "src/backends/handler/torch_scripted_handler.hh" + +#include + +#include "src/utils/message.hh" +#include "src/utils/metrics/registry.hh" + +namespace torchserve { +std::pair, std::shared_ptr> +TorchScriptHandler::LoadModel( + std::shared_ptr& load_model_request) { + try { + auto device = GetTorchDevice(load_model_request); + std::shared_ptr module( + std::make_shared(torch::jit::load( + // TODO: windows + fmt::format("{}/{}", load_model_request->model_dir, + manifest_->GetModel().serialized_file), + 
*device))); + return std::make_pair(module, device); + } catch (const c10::Error& e) { + TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}", + load_model_request->model_name, load_model_request->gpu_id, + e.msg()); + throw e; + } catch (const std::runtime_error& e) { + TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}", + load_model_request->model_name, load_model_request->gpu_id, + e.what()); + throw e; + } +} +} // namespace torchserve diff --git a/cpp/src/backends/handler/torch_scripted_handler.hh b/cpp/src/backends/handler/torch_scripted_handler.hh new file mode 100644 index 0000000000..1e7c816a34 --- /dev/null +++ b/cpp/src/backends/handler/torch_scripted_handler.hh @@ -0,0 +1,10 @@ +#pragma once +#include "base_handler.hh" + +namespace torchserve { + +class TorchScriptHandler : public BaseHandler { + std::pair, std::shared_ptr> LoadModel( + std::shared_ptr& load_model_request) override; +}; +} // namespace torchserve diff --git a/cpp/src/backends/process/model_worker.cc b/cpp/src/backends/process/model_worker.cc index 89c38fdd47..ae6afaeb01 100644 --- a/cpp/src/backends/process/model_worker.cc +++ b/cpp/src/backends/process/model_worker.cc @@ -110,7 +110,7 @@ bool SocketServer::CreateBackend( const torchserve::Manifest::RuntimeType& runtime_type, const std::string& model_dir) { if (runtime_type == "LSP") { - backend_ = std::make_shared(); + backend_ = std::make_shared(); return backend_->Initialize(model_dir); } return false; diff --git a/cpp/src/backends/process/model_worker.hh b/cpp/src/backends/process/model_worker.hh index 4659dcd474..0448fc7238 100644 --- a/cpp/src/backends/process/model_worker.hh +++ b/cpp/src/backends/process/model_worker.hh @@ -1,5 +1,4 @@ -#ifndef TS_CPP_BACKENDS_PROCESS_MODEL_WORKER_HH_ -#define TS_CPP_BACKENDS_PROCESS_MODEL_WORKER_HH_ +#pragma once #include #include @@ -15,7 +14,6 @@ #include "src/backends/core/backend.hh" #include "src/backends/protocol/otf_message.hh" -#include "src/backends/torch_scripted/torch_scripted_backend.hh" #include "src/utils/config.hh" #include "src/utils/logging.hh" #include "src/utils/model_archive.hh" @@ -77,4 +75,3 @@ class SocketModelWorker { std::shared_ptr backend_; }; } // namespace torchserve -#endif // TS_CPP_BACKENDS_PROCESS_MODEL_WORKER_HH_ \ No newline at end of file diff --git a/cpp/src/backends/torch_scripted/torch_scripted_backend.cc b/cpp/src/backends/torch_scripted/torch_scripted_backend.cc deleted file mode 100644 index 87f8498c42..0000000000 --- a/cpp/src/backends/torch_scripted/torch_scripted_backend.cc +++ /dev/null @@ -1,83 +0,0 @@ -#include "src/backends/torch_scripted/torch_scripted_backend.hh" - -namespace torchserve { -namespace torchscripted { -bool Backend::Initialize(const std::string& model_dir) { - if (!torchserve::Backend::Initialize(model_dir)) { - return false; - } - LoadHandler(model_dir); - if (!handler_) { - return false; - } - handler_->Initialize(model_dir, manifest_); - - // TODO: support request envelope: - // https://github.com/pytorch/serve/tree/master/ts/torch_handler/request_envelope - return true; -} - -void Backend::LoadHandler(const std::string& model_dir) { - const std::string& handler_str = manifest_->GetModel().handler; - std::size_t delimiter_pos = handler_str.find(manifest_->kHandler_Delimiter); - if (delimiter_pos != std::string::npos) { -#ifdef __APPLE__ - std::string lib_path = fmt::format("{}/{}.dylib", model_dir, - handler_str.substr(0, delimiter_pos)); -#else - std::string lib_path = fmt::format("{}/{}.so", model_dir, - 
handler_str.substr(0, delimiter_pos)); -#endif - std::string handler_class_name = handler_str.substr(delimiter_pos + 1); - std::string allocator_func = fmt::format("allocator{}", handler_class_name); - std::string deleter_func = fmt::format("deleter{}", handler_class_name); - - dl_loader_ = std::make_unique>( - lib_path, allocator_func, deleter_func); - dl_loader_->OpenDL(); - handler_ = dl_loader_->GetInstance(); - } else { - handler_ = HandlerFactory::GetInstance().createHandler(handler_str); - } -} - -std::unique_ptr Backend::LoadModelInternal( - std::shared_ptr load_model_request) { - std::string model_instance_id = BuildModelInstanceId(load_model_request); - try { - model_instance_table_[model_instance_id] = { - torchserve::Backend::ModelInstanceStatus::INIT, - std::shared_ptr(nullptr)}; - - auto result = handler_->LoadModel(load_model_request); - SetModelInstanceInfo( - model_instance_id, torchserve::Backend::ModelInstanceStatus::READY, - std::make_shared( - model_instance_id, std::move(result.first), handler_, - std::move(result.second))); - - ready_model_instance_ids_.emplace_back(model_instance_id); - std::string message = - fmt::format("loaded model {}", load_model_request->model_name); - return std::make_unique( - // TODO: check current response msg content - 200, message); - } catch (const c10::Error& e) { - SetModelInstanceInfo(model_instance_id, - torchserve::Backend::ModelInstanceStatus::FAILED, - std::shared_ptr(nullptr)); - return std::make_unique( - // TODO: check existing - 500, e.msg()); - } -} - -std::shared_ptr ModelInstance::Predict( - std::shared_ptr request_batch) { - auto response_batch = std::make_shared(); - handler_->Handle(model_, device_, request_batch, response_batch); - - return response_batch; -} -} // namespace torchscripted -} // namespace torchserve \ No newline at end of file diff --git a/cpp/src/backends/torch_scripted/torch_scripted_backend.hh b/cpp/src/backends/torch_scripted/torch_scripted_backend.hh deleted file mode 100644 index 0ca4ffa6d7..0000000000 --- a/cpp/src/backends/torch_scripted/torch_scripted_backend.hh +++ /dev/null @@ -1,66 +0,0 @@ -#ifndef TS_CPP_BACKENDS_TORCH_SCRIPTED_TORCH_SCRIPTED_BACKEND_HH_ -#define TS_CPP_BACKENDS_TORCH_SCRIPTED_TORCH_SCRIPTED_BACKEND_HH_ - -#include -#include -#include - -#include - -#include "src/backends/core/backend.hh" -#include "src/backends/torch_scripted/handler/base_handler.hh" -#include "src/backends/torch_scripted/handler/handler_factory.hh" -#include "src/utils/dl_loader.hh" -#include "src/utils/logging.hh" -#include "src/utils/message.hh" -#include "src/utils/model_archive.hh" - -namespace torchserve { -namespace torchscripted { -class Backend final : public torchserve::Backend { - public: - Backend() = default; - ~Backend() override { - if (dl_loader_ && handler_) { - handler_.reset(); - } - }; - - bool Initialize(const std::string& model_dir) override; - - std::unique_ptr LoadModelInternal( - std::shared_ptr load_model_request) - override; - - private: - void LoadHandler(const std::string& model_dir); - - std::unique_ptr> dl_loader_; - std::shared_ptr handler_; -}; - -class ModelInstance final : public torchserve::ModelInstance { - public: - ModelInstance( - const std::string& instance_id, - std::shared_ptr model, - std::shared_ptr& handler, - std::shared_ptr device) - : torchserve::ModelInstance(instance_id), - model_(model), - handler_(handler), - device_(device){}; - ~ModelInstance() override = default; - - std::shared_ptr Predict( - std::shared_ptr request_batch) - override; - - private: - 
std::shared_ptr model_; - std::shared_ptr handler_; - std::shared_ptr device_; -}; -} // namespace torchscripted -} // namespace torchserve -#endif // TS_CPP_BACKENDS_TORCH_SCRIPTED_TORCH_SCRIPTED_BACKEND_HH_ \ No newline at end of file diff --git a/cpp/src/examples/CMakeLists.txt b/cpp/src/examples/CMakeLists.txt index 4c9c534097..d5402a5faa 100644 --- a/cpp/src/examples/CMakeLists.txt +++ b/cpp/src/examples/CMakeLists.txt @@ -4,4 +4,13 @@ set(MNIST_SOURCE_FILES "") list(APPEND MNIST_SOURCE_FILES ${MNIST_SRC_DIR}/mnist_handler.cc) add_library(mnist_handler SHARED ${MNIST_SOURCE_FILES}) target_include_directories(mnist_handler PUBLIC ${MNIST_SRC_DIR}) -target_link_libraries(mnist_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES}) +target_link_libraries(mnist_handler PRIVATE ts_backends_core ts_utils ${TORCH_LIBRARIES}) + + +set(BABYLLAMA_SRC_DIR "${torchserve_cpp_SOURCE_DIR}/src/examples/babyllama") +set(BABYLLAMA_SOURCE_FILES "") +list(APPEND BABYLLAMA_SOURCE_FILES ${BABYLLAMA_SRC_DIR}/baby_llama_handler.cc) +add_library(babyllama_handler SHARED ${BABYLLAMA_SOURCE_FILES}) +target_include_directories(babyllama_handler PUBLIC ${BABYLLAMA_SRC_DIR}) +target_link_libraries(babyllama_handler PRIVATE ts_backends_core ts_utils ${TORCH_LIBRARIES}) +target_compile_options(babyllama_handler PRIVATE -Wall -Wextra -Ofast) diff --git a/cpp/src/examples/babyllama/baby_llama_handler.cc b/cpp/src/examples/babyllama/baby_llama_handler.cc new file mode 100644 index 0000000000..62980e5f78 --- /dev/null +++ b/cpp/src/examples/babyllama/baby_llama_handler.cc @@ -0,0 +1,304 @@ +#include "src/examples/babyllama/baby_llama_handler.hh" + +#include +#include + +#include + +#include "src/examples/babyllama/llama2.c/run.c" + +namespace llm { + +Transformer transformer; +Tokenizer tokenizer; +Sampler sampler; +int steps = 256; + +std::pair, std::shared_ptr> +BabyLlamaHandler::LoadModel( + std::shared_ptr &load_model_request) { + try { + auto device = GetTorchDevice(load_model_request); + + const std::string configFilePath = + fmt::format("{}/{}", load_model_request->model_dir, "config.json"); + std::string jsonContent; + if (!folly::readFile(configFilePath.c_str(), jsonContent)) { + std::cerr << "config.json not found at: " << configFilePath << std::endl; + throw; + } + folly::dynamic json; + json = folly::parseJson(jsonContent); + std::string checkpoint_path; + std::string tokenizer_path; + if (json.find("checkpoint_path") != json.items().end() && + json.find("tokenizer_path") != json.items().end()) { + checkpoint_path = json["checkpoint_path"].asString(); + tokenizer_path = json["tokenizer_path"].asString(); + } else { + std::cerr + << "Required fields 'model_name' and 'model_path' not found in JSON." + << std::endl; + throw; + } + + build_transformer(&transformer, + const_cast(checkpoint_path.c_str())); + + build_tokenizer(&tokenizer, const_cast(tokenizer_path.c_str()), + transformer.config.vocab_size); + + float temperature = + 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher + float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 
0.9 works well, + // but slower + unsigned long long rng_seed(0); + // build the Sampler + build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, + rng_seed); + + return std::make_pair(nullptr, device); + } catch (const c10::Error &e) { + TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}", + load_model_request->model_name, load_model_request->gpu_id, + e.msg()); + throw e; + } catch (const std::runtime_error &e) { + TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}", + load_model_request->model_name, load_model_request->gpu_id, + e.what()); + throw e; + } +} + +c10::IValue BabyLlamaHandler::Preprocess( + std::shared_ptr &device, + std::pair &> &idx_to_req_id, + std::shared_ptr &request_batch, + std::shared_ptr &response_batch) { + auto batch_ivalue = c10::impl::GenericList(torch::TensorType::get()); + std::vector batch_tensors; + uint8_t idx = 0; + for (auto &request : *request_batch) { + try { + (*response_batch)[request.request_id] = + std::make_shared(request.request_id); + idx_to_req_id.first += idx_to_req_id.first.empty() + ? request.request_id + : "," + request.request_id; + + auto data_it = request.parameters.find( + torchserve::PayloadType::kPARAMETER_NAME_DATA); + auto dtype_it = + request.headers.find(torchserve::PayloadType::kHEADER_NAME_DATA_TYPE); + if (data_it == request.parameters.end()) { + data_it = request.parameters.find( + torchserve::PayloadType::kPARAMETER_NAME_BODY); + dtype_it = request.headers.find( + torchserve::PayloadType::kHEADER_NAME_BODY_TYPE); + } + + if (data_it == request.parameters.end() || + dtype_it == request.headers.end()) { + TS_LOGF(ERROR, "Empty payload for request id: {}", request.request_id); + (*response_batch)[request.request_id]->SetResponse( + 500, "data_type", torchserve::PayloadType::kCONTENT_TYPE_TEXT, + "Empty payload"); + continue; + } + + std::string msg = torchserve::Converter::VectorToStr(data_it->second); + + int num_prompt_tokens = 0; + + std::unique_ptr msgCStr( + new char[msg.size() + 1], [](char *ptr) { delete[] ptr; }); + + std::strcpy(msgCStr.get(), msg.c_str()); + + std::unique_ptr prompt_tokens(new int[msg.length() + 3]); + + encode(&tokenizer, msgCStr.get(), 1, 0, prompt_tokens.get(), + &num_prompt_tokens); + + std::vector tensor_vector; + for (int64_t i = 0; i < num_prompt_tokens; ++i) { + int token = prompt_tokens[i]; + torch::Tensor tensor = torch::tensor(token, torch::kInt64); + tensor_vector.push_back(tensor); + } + batch_ivalue.emplace_back(torch::stack(tensor_vector)); + + idx_to_req_id.second[idx++] = request.request_id; + } catch (const std::runtime_error &e) { + TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}", + request.request_id, e.what()); + auto response = (*response_batch)[request.request_id]; + response->SetResponse(500, "data_type", + torchserve::PayloadType::kDATA_TYPE_STRING, + "runtime_error, failed to load tensor"); + } catch (const c10::Error &e) { + TS_LOGF(ERROR, "Failed to load tensor for request id: {}, c10 error:{}", + request.request_id, e.msg()); + auto response = (*response_batch)[request.request_id]; + response->SetResponse(500, "data_type", + torchserve::PayloadType::kDATA_TYPE_STRING, + "c10 error, failed to load tensor"); + } + } + + return batch_ivalue; +} + +c10::IValue BabyLlamaHandler::Inference( + std::shared_ptr model, c10::IValue &inputs, + std::shared_ptr &device, + std::pair &> &idx_to_req_id, + std::shared_ptr &response_batch) { + torch::InferenceMode guard; + auto batch_output_vector = 
c10::impl::GenericList(torch::TensorType::get()); + long batch_token_length = 0; + long start = + 0; // used to time our code, only initialized after first iteration + + try { + for (auto input : inputs.toTensorList()) { + std::vector tensor_vector; + tensor_vector.reserve(steps); + torch::Tensor tokens_list_tensor = input.get().toTensor(); + + int64_t num_elements = tokens_list_tensor.numel(); + + int64_t *data_ptr = tokens_list_tensor.data_ptr(); + + std::unique_ptr prompt_tokens(new int[num_elements]); + + for (int64_t i = 0; i < num_elements; ++i) { + prompt_tokens[i] = data_ptr[i]; + } + + // start the main loop + int next; // will store the next token in the sequence + int token = + prompt_tokens[0]; // kick off with the first token in the prompt + int pos = 0; // position in the sequence + while (pos < steps) { + // forward the transformer to get logits for the next token + float *logits = forward(&transformer, token, pos); + + // advance the state state machine + if (pos < num_elements - 1) { + // if we are still processing the input prompt, force the next prompt + // token + next = prompt_tokens[pos + 1]; + } else { + // otherwise sample the next token from the logits + next = sample(&sampler, logits); + } + pos++; + + torch::Tensor tensor = torch::tensor(next, torch::kLong); + tensor_vector.push_back(tensor); + + // data-dependent terminating condition: the BOS (=1) token delimits + // sequences + if (next == 1) { + break; + } + token = next; + + // init the timer here because the first iteration can be slower + if (start == 0) { + start = time_in_ms(); + } + } + batch_token_length = batch_token_length + pos - 1; + + torch::Tensor stacked_tensor = torch::stack(tensor_vector); + + batch_output_vector.push_back(stacked_tensor); + } + + std::cout << "Total number of tokens generated: " << batch_token_length + << std::endl; + if (batch_token_length > 1) { + long end = time_in_ms(); + double token_per_sec = batch_token_length / (double)(end - start) * 1000; + std::cout << "Achieved tok per sec: " << token_per_sec << std::endl; + } + } catch (std::runtime_error &e) { + TS_LOG(ERROR, e.what()); + } catch (const c10::Error &e) { + TS_LOGF(ERROR, "Failed to apply inference on input, c10 error:{}", e.msg()); + } catch (...) { + TS_LOG(ERROR, "Failed to run inference on this batch"); + } + std::cout << "WOOT?" 
<< std::endl; + return batch_output_vector; +} + +void BabyLlamaHandler::Postprocess( + c10::IValue &outputs, + std::pair &> &idx_to_req_id, + std::shared_ptr &response_batch) { + auto data = outputs.toTensorList(); + for (const auto &kv : idx_to_req_id.second) { + try { + int64_t num_elements = data[kv.first].get().toTensor().numel(); + int64_t *data_ptr = data[kv.first].get().toTensor().data_ptr(); + int64_t token = 1; + std::string concatenated_string; + for (int64_t i = 0; i < num_elements; ++i) { + char *piece = decode(&tokenizer, token, data_ptr[i]); + std::string piece_string(piece); + token = data_ptr[i]; + concatenated_string += piece_string; + } + + std::cout << "Generated String: " << concatenated_string << std::endl; + + auto response = (*response_batch)[kv.second]; + + response->SetResponse(200, "data_type", + torchserve::PayloadType::kDATA_TYPE_STRING, + concatenated_string); + } catch (const std::runtime_error &e) { + TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}", + kv.second, e.what()); + auto response = (*response_batch)[kv.second]; + response->SetResponse(500, "data_type", + torchserve::PayloadType::kDATA_TYPE_STRING, + "runtime_error, failed to postprocess tensor"); + } catch (const c10::Error &e) { + TS_LOGF(ERROR, + "Failed to postprocess tensor for request id: {}, error: {}", + kv.second, e.msg()); + auto response = (*response_batch)[kv.second]; + response->SetResponse(500, "data_type", + torchserve::PayloadType::kDATA_TYPE_STRING, + "c10 error, failed to postprocess tensor"); + } + } +} + +BabyLlamaHandler::~BabyLlamaHandler() noexcept { + free_sampler(&sampler); + free_tokenizer(&tokenizer); + free_transformer(&transformer); +} + +} // namespace llm + +#if defined(__linux__) || defined(__APPLE__) +extern "C" { +torchserve::BaseHandler *allocatorBabyLlamaHandler() { + return new llm::BabyLlamaHandler(); +} + +void deleterBabyLlamaHandler(torchserve::BaseHandler *p) { + if (p != nullptr) { + delete static_cast(p); + } +} +} +#endif diff --git a/cpp/src/examples/babyllama/baby_llama_handler.hh b/cpp/src/examples/babyllama/baby_llama_handler.hh new file mode 100644 index 0000000000..559d628769 --- /dev/null +++ b/cpp/src/examples/babyllama/baby_llama_handler.hh @@ -0,0 +1,41 @@ +#pragma once + +#include + +#include "src/backends/handler/base_handler.hh" + +namespace llm { +class BabyLlamaHandler : public torchserve::BaseHandler { + public: + // NOLINTBEGIN(bugprone-exception-escape) + BabyLlamaHandler() = default; + // NOLINTEND(bugprone-exception-escape) + ~BabyLlamaHandler() noexcept; + + void initialize_context(); + + std::pair, std::shared_ptr> LoadModel( + std::shared_ptr& load_model_request) + override; + + c10::IValue Preprocess( + std::shared_ptr& device, + std::pair&>& idx_to_req_id, + std::shared_ptr& request_batch, + std::shared_ptr& response_batch) + override; + + c10::IValue Inference( + std::shared_ptr model, c10::IValue& inputs, + std::shared_ptr& device, + std::pair&>& idx_to_req_id, + std::shared_ptr& response_batch) + override; + + void Postprocess( + c10::IValue& data, + std::pair&>& idx_to_req_id, + std::shared_ptr& response_batch) + override; +}; +} // namespace llm diff --git a/cpp/src/examples/babyllama/llama2.c/LICENSE b/cpp/src/examples/babyllama/llama2.c/LICENSE new file mode 100644 index 0000000000..2ad12227f9 --- /dev/null +++ b/cpp/src/examples/babyllama/llama2.c/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Andrej + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this 
software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/cpp/src/examples/babyllama/llama2.c/run.c b/cpp/src/examples/babyllama/llama2.c/run.c new file mode 100644 index 0000000000..cacd1414af --- /dev/null +++ b/cpp/src/examples/babyllama/llama2.c/run.c @@ -0,0 +1,863 @@ +/* Inference for Llama-2 Transformer model in pure C */ + +#include +#include +#include +#include +#include +#include +#include +#if defined _WIN32 + #include "win.h" +#else + #include + #include +#endif +// ---------------------------------------------------------------------------- +// Transformer model + +typedef struct { + int dim; // transformer dimension + int hidden_dim; // for ffn layers + int n_layers; // number of layers + int n_heads; // number of query heads + int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery) + int vocab_size; // vocabulary size, usually 256 (byte-level) + int seq_len; // max sequence length +} Config; + +typedef struct { + // token embedding table + float* token_embedding_table; // (vocab_size, dim) + // weights for rmsnorms + float* rms_att_weight; // (layer, dim) rmsnorm weights + float* rms_ffn_weight; // (layer, dim) + // weights for matmuls. 
note dim == n_heads * head_size + float* wq; // (layer, dim, n_heads * head_size) + float* wk; // (layer, dim, n_kv_heads * head_size) + float* wv; // (layer, dim, n_kv_heads * head_size) + float* wo; // (layer, n_heads * head_size, dim) + // weights for ffn + float* w1; // (layer, hidden_dim, dim) + float* w2; // (layer, dim, hidden_dim) + float* w3; // (layer, hidden_dim, dim) + // final rmsnorm + float* rms_final_weight; // (dim,) + // (optional) classifier weights for the logits, on the last layer + float* wcls; +} TransformerWeights; + +typedef struct { + // current wave of activations + float *x; // activation at current time stamp (dim,) + float *xb; // same, but inside a residual branch (dim,) + float *xb2; // an additional buffer just for convenience (dim,) + float *hb; // buffer for hidden dimension in the ffn (hidden_dim,) + float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,) + float *q; // query (dim,) + float *k; // key (dim,) + float *v; // value (dim,) + float *att; // buffer for scores/attention values (n_heads, seq_len) + float *logits; // output logits + // kv cache + float* key_cache; // (layer, seq_len, dim) + float* value_cache; // (layer, seq_len, dim) +} RunState; + +typedef struct { + Config config; // the hyperparameters of the architecture (the blueprint) + TransformerWeights weights; // the weights of the model + RunState state; // buffers for the "wave" of activations in the forward pass + // some more state needed to properly clean up the memory mapping (sigh) + int fd; // file descriptor for memory mapping + float* data; // memory mapped data pointer + ssize_t file_size; // size of the checkpoint file in bytes +} Transformer; + +void malloc_run_state(RunState* s, Config* p) { + // we calloc instead of malloc to keep valgrind happy + int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads; + s->x = (float*)calloc(p->dim, sizeof(float)); + s->xb = (float*)calloc(p->dim, sizeof(float)); + s->xb2 = (float*)calloc(p->dim, sizeof(float)); + s->hb = (float*)calloc(p->hidden_dim, sizeof(float)); + s->hb2 = (float*)calloc(p->hidden_dim, sizeof(float)); + s->q = (float*)calloc(p->dim, sizeof(float)); + s->k = (float*)calloc(kv_dim, sizeof(float)); + s->v = (float*)calloc(kv_dim, sizeof(float)); + s->att = (float*)calloc(p->n_heads * p->seq_len, sizeof(float)); + s->logits = (float*)calloc(p->vocab_size, sizeof(float)); + s->key_cache = (float*)calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float)); + s->value_cache = (float*)calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float)); + // ensure all mallocs went fine + if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q + || !s->k || !s->v || !s->att || !s->logits || !s->key_cache + || !s->value_cache) { + fprintf(stderr, "malloc failed!\n"); + exit(EXIT_FAILURE); + } +} + +void free_run_state(RunState* s) { + free(s->x); + free(s->xb); + free(s->xb2); + free(s->hb); + free(s->hb2); + free(s->q); + free(s->k); + free(s->v); + free(s->att); + free(s->logits); + free(s->key_cache); + free(s->value_cache); +} + +void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) { + int head_size = p->dim / p->n_heads; + w->token_embedding_table = ptr; + ptr += p->vocab_size * p->dim; + w->rms_att_weight = ptr; + ptr += p->n_layers * p->dim; + w->wq = ptr; + ptr += p->n_layers * p->dim * (p->n_heads * head_size); + w->wk = ptr; + ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size); + w->wv = ptr; + ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size); + w->wo = ptr; + 
ptr += p->n_layers * (p->n_heads * head_size) * p->dim; + w->rms_ffn_weight = ptr; + ptr += p->n_layers * p->dim; + w->w1 = ptr; + ptr += p->n_layers * p->dim * p->hidden_dim; + w->w2 = ptr; + ptr += p->n_layers * p->hidden_dim * p->dim; + w->w3 = ptr; + ptr += p->n_layers * p->dim * p->hidden_dim; + w->rms_final_weight = ptr; + ptr += p->dim; + ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_real (for RoPE) + ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_imag (for RoPE) + w->wcls = shared_weights ? w->token_embedding_table : ptr; +} + +void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights, + int* fd, float** data, ssize_t* file_size) { + FILE *file = fopen(checkpoint, "rb"); + if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); } + // read in the config header + if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); } + // negative vocab size is hacky way of signaling unshared weights. bit yikes. + int shared_weights = config->vocab_size > 0 ? 1 : 0; + config->vocab_size = abs(config->vocab_size); + // figure out the file size + fseek(file, 0, SEEK_END); // move file pointer to end of file + *file_size = ftell(file); // get the file size, in bytes + fclose(file); + // memory map the Transformer weights into the data pointer + *fd = open(checkpoint, O_RDONLY); // open in read only mode + if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); } + *data = (float*)mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0); + if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); } + float* weights_ptr = *data + sizeof(Config)/sizeof(float); + memory_map_weights(weights, config, weights_ptr, shared_weights); +} + +void build_transformer(Transformer *t, char* checkpoint_path) { + // read in the Config and the Weights from the checkpoint + read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size); + // allocate the RunState buffers + malloc_run_state(&t->state, &t->config); +} + +void free_transformer(Transformer* t) { + // close the memory mapping + if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); } + if (t->fd != -1) { close(t->fd); } + // free the RunState buffers + free_run_state(&t->state); +} + +// ---------------------------------------------------------------------------- +// neural net blocks; the dynamics of the Transformer + +void rmsnorm(float* o, float* x, float* weight, int size) { + // calculate sum of squares + float ss = 0.0f; + for (int j = 0; j < size; j++) { + ss += x[j] * x[j]; + } + ss /= size; + ss += 1e-5f; + ss = 1.0f / sqrtf(ss); + // normalize and scale + for (int j = 0; j < size; j++) { + o[j] = weight[j] * (ss * x[j]); + } +} + +void softmax(float* x, int size) { + // find max value (for numerical stability) + float max_val = x[0]; + for (int i = 1; i < size; i++) { + if (x[i] > max_val) { + max_val = x[i]; + } + } + // exp and sum + float sum = 0.0f; + for (int i = 0; i < size; i++) { + x[i] = expf(x[i] - max_val); + sum += x[i]; + } + // normalize + for (int i = 0; i < size; i++) { + x[i] /= sum; + } +} + +void matmul(float* xout, float* x, float* w, int n, int d) { + // W (d,n) @ x (n,) -> xout (d,) + // by far the most amount of time is spent inside this little function + int i; + #pragma omp parallel for private(i) + for (i = 0; i < d; i++) { + float val = 0.0f; + for (int j = 0; j < n; j++) { + val += w[i * n + j] * x[j]; + } + xout[i] = val; + } +} + 
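/* [Editor's note; this comment is not part of the vendored llama2.c source.]
 * forward() below computes a single decoding step. Per layer it applies the
 * attention rmsnorm, the q/k/v matmuls, the RoPE rotation of q and k, per-head
 * attention over the kv cache (scores softmaxed over positions 0..pos) with a
 * weighted sum of the cached values, the output projection plus residual, then
 * the ffn rmsnorm and the w1/w3 -> SwiGLU -> w2 feed-forward with another
 * residual. A final rmsnorm and the classifier matmul produce the logits
 * returned to the caller. */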
+float* forward(Transformer* transformer, int token, int pos) { + // a few convenience variables + Config* p = &transformer->config; + TransformerWeights* w = &transformer->weights; + RunState* s = &transformer->state; + float *x = s->x; + int dim = p->dim; + int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads; + int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery + int hidden_dim = p->hidden_dim; + int head_size = dim / p->n_heads; + + // copy the token embedding into x + float* content_row = w->token_embedding_table + token * dim; + memcpy(x, content_row, dim*sizeof(*x)); + + // forward all the layers + for(int l = 0; l < p->n_layers; l++) { + + // attention rmsnorm + rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim); + + // qkv matmuls for this position + matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim); + matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim); + matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim); + + // RoPE relative positional encoding: complex-valued rotate q and k in each head + for (int i = 0; i < dim; i+=2) { + int head_dim = i % head_size; + float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size); + float val = pos * freq; + float fcr = cosf(val); + float fci = sinf(val); + int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only + for (int v = 0; v < rotn; v++) { + float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key) + float v0 = vec[i]; + float v1 = vec[i+1]; + vec[i] = v0 * fcr - v1 * fci; + vec[i+1] = v0 * fci + v1 * fcr; + } + } + + // save key,value at this time step (pos) to our kv cache + int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience + float* key_cache_row = s->key_cache + loff + pos * kv_dim; + float* value_cache_row = s->value_cache + loff + pos * kv_dim; + memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row)); + memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row)); + + // multihead attention. 
iterate over all heads + int h; + #pragma omp parallel for private(h) + for (h = 0; h < p->n_heads; h++) { + // get the query vector for this head + float* q = s->q + h * head_size; + // attention scores for this head + float* att = s->att + h * p->seq_len; + // iterate over all timesteps, including the current one + for (int t = 0; t <= pos; t++) { + // get the key vector for this head and at this timestep + float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size; + // calculate the attention score as the dot product of q and k + float score = 0.0f; + for (int i = 0; i < head_size; i++) { + score += q[i] * k[i]; + } + score /= sqrtf(head_size); + // save the score to the attention buffer + att[t] = score; + } + + // softmax the scores to get attention weights, from 0..pos inclusively + softmax(att, pos + 1); + + // weighted sum of the values, store back into xb + float* xb = s->xb + h * head_size; + memset(xb, 0, head_size * sizeof(float)); + for (int t = 0; t <= pos; t++) { + // get the value vector for this head and at this timestep + float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size; + // get the attention weight for this timestep + float a = att[t]; + // accumulate the weighted value into xb + for (int i = 0; i < head_size; i++) { + xb[i] += a * v[i]; + } + } + } + + // final matmul to get the output of the attention + matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim); + + // residual connection back into x + for (int i = 0; i < dim; i++) { + x[i] += s->xb2[i]; + } + + // ffn rmsnorm + rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim); + + // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) + // first calculate self.w1(x) and self.w3(x) + matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim); + matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim); + + // SwiGLU non-linearity + for (int i = 0; i < hidden_dim; i++) { + float val = s->hb[i]; + // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid + val *= (1.0f / (1.0f + expf(-val))); + // elementwise multiply with w3(x) + val *= s->hb2[i]; + s->hb[i] = val; + } + + // final matmul to get the output of the ffn + matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim); + + // residual connection + for (int i = 0; i < dim; i++) { + x[i] += s->xb[i]; + } + } + + // final rmsnorm + rmsnorm(x, x, w->rms_final_weight, dim); + + // classifier into logits + matmul(s->logits, x, w->wcls, p->dim, p->vocab_size); + return s->logits; +} + +// ---------------------------------------------------------------------------- +// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens + +typedef struct { + char *str; + int id; +} TokenIndex; + +typedef struct { + char** vocab; + float* vocab_scores; + TokenIndex *sorted_vocab; + int vocab_size; + unsigned int max_token_length; + unsigned char byte_pieces[512]; // stores all single-byte strings +} Tokenizer; + +int compare_tokens(const void *a, const void *b) { + return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str); +} + +void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) { + // i should have written the vocab_size into the tokenizer file... 
sigh + t->vocab_size = vocab_size; + // malloc space to hold the scores and the strings + t->vocab = (char**)malloc(vocab_size * sizeof(char*)); + t->vocab_scores = (float*)malloc(vocab_size * sizeof(float)); + t->sorted_vocab = NULL; // initialized lazily + for (int i = 0; i < 256; i++) { + t->byte_pieces[i * 2] = (unsigned char)i; + t->byte_pieces[i * 2 + 1] = '\0'; + } + // read in the file + FILE *file = fopen(tokenizer_path, "rb"); + if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); } + if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } + int len; + for (int i = 0; i < vocab_size; i++) { + if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);} + if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } + t->vocab[i] = (char *)malloc(len + 1); + if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } + t->vocab[i][len] = '\0'; // add the string terminating token + } + fclose(file); +} + +void free_tokenizer(Tokenizer* t) { + for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); } + free(t->vocab); + free(t->vocab_scores); + free(t->sorted_vocab); +} + +char* decode(Tokenizer* t, int prev_token, int token) { + char *piece = t->vocab[token]; + // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) + if (prev_token == 1 && piece[0] == ' ') { piece++; } + // careful, some tokens designate raw bytes, and look like e.g. '<0x01>' + // parse this and convert and return the actual byte + unsigned char byte_val; + if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) { + piece = (char*)t->byte_pieces + byte_val * 2; + } + return piece; +} + +void safe_printf(char *piece) { + // piece might be a raw byte token, and we only want to print printable chars or whitespace + // because some of the other bytes can be various control codes, backspace, etc. + if (piece == NULL) { return; } + if (piece[0] == '\0') { return; } + if (piece[1] == '\0') { + unsigned char byte_val = piece[0]; + if (!(isprint(byte_val) || isspace(byte_val))) { + return; // bad byte, don't print it + } + } + printf("%s", piece); +} + +int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { + // efficiently find the perfect match for str in vocab, return its index or -1 if not found + TokenIndex tok = { .str = str }; // acts as the key to search for + TokenIndex *res = (TokenIndex*)bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); + return res != NULL ? 
res->id : -1; +} + +void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) { + // encode the string text (input) into an upper-bound preallocated tokens[] array + // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2) + if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); } + + if (t->sorted_vocab == NULL) { + // lazily malloc and sort the vocabulary + t->sorted_vocab = (TokenIndex*)malloc(t->vocab_size * sizeof(TokenIndex)); + for (int i = 0; i < t->vocab_size; i++) { + t->sorted_vocab[i].str = t->vocab[i]; + t->sorted_vocab[i].id = i; + } + qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens); + } + + // create a temporary buffer that will store merge candidates of always two consecutive tokens + // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1) + char* str_buffer = (char*)malloc((t->max_token_length*2 +1 +2) * sizeof(char)); + size_t str_len = 0; + + // start at 0 tokens + *n_tokens = 0; + + // add optional BOS (=1) token, if desired + if (bos) tokens[(*n_tokens)++] = 1; + + // add_dummy_prefix is true by default + // so prepend a dummy prefix token to the input string, but only if text != "" + // TODO: pretty sure this isn't correct in the general case but I don't have the + // energy to read more of the sentencepiece code to figure out what it's doing + if (text[0] != '\0') { + int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size); + tokens[(*n_tokens)++] = dummy_prefix; + } + + // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia: + // Code point ↔ UTF-8 conversion + // First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4 + // U+0000 U+007F 0xxxxxxx + // U+0080 U+07FF 110xxxxx 10xxxxxx + // U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx + // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + + // process the raw (UTF-8) byte sequence of the input string + for (char *c = text; *c != '\0'; c++) { + + // reset buffer if the current byte is ASCII or a leading byte + // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest + // 0x80 is 10000000 + // in UTF-8, all continuation bytes start with "10" in first two bits + // so in English this is: "if this byte is not a continuation byte" + if ((*c & 0xC0) != 0x80) { + // this byte must be either a leading byte (11...) or an ASCII char (0x...) + // => reset our location, as we're starting a new UTF-8 codepoint + str_len = 0; + } + + // append the current byte to the buffer + str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line + str_buffer[str_len] = '\0'; + + // while the next character is a continuation byte, continue appending + // but if there are too many of them, just stop to avoid overruning str_buffer size. 
+ if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) { + continue; + } + + // ok c+1 is not a continuation byte, so we've read in a full codepoint + int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size); + + if (id != -1) { + // we found this codepoint in vocab, add it as a token + tokens[(*n_tokens)++] = id; + } else { + // byte_fallback encoding: just encode each byte as a token + // +3 is here because the first 3 vocab elements are <unk>, <s>, </s> + // so the individual bytes only start at index 3 + for (int i=0; i < str_len; i++) { + tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3; + } + } + str_len = 0; // protect against a sequence of stray UTF8 continuation bytes + } + + // merge the best consecutive pair each iteration, according to the scores in vocab_scores + while (1) { + float best_score = -1e10; + int best_id = -1; + int best_idx = -1; + + for (int i=0; i < (*n_tokens-1); i++) { + // check if we can merge the pair (tokens[i], tokens[i+1]) + sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]); + int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size); + if (id != -1 && t->vocab_scores[id] > best_score) { + // this merge pair exists in vocab! record its score and position + best_score = t->vocab_scores[id]; + best_id = id; + best_idx = i; + } + } + + if (best_idx == -1) { + break; // we couldn't find any more pairs to merge, so we're done + } + + // merge the consecutive pair (best_idx, best_idx+1) into new token best_id + tokens[best_idx] = best_id; + // delete token at position best_idx+1, shift the entire sequence back 1 + for (int i = best_idx+1; i < (*n_tokens-1); i++) { + tokens[i] = tokens[i+1]; + } + (*n_tokens)--; // token length decreased + } + + // add optional EOS (=2) token, if desired + if (eos) tokens[(*n_tokens)++] = 2; + + free(str_buffer); +} + +// ---------------------------------------------------------------------------- +// The Sampler, which takes logits and returns a sampled token +// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling + +typedef struct { + float prob; + int index; +} ProbIndex; // struct used when sorting probabilities during top-p sampling + +typedef struct { + int vocab_size; + ProbIndex* probindex; // buffer used in top-p sampling + float temperature; + float topp; + unsigned long long rng_state; +} Sampler; + +int sample_argmax(float* probabilities, int n) { + // return the index that has the highest probability + int max_i = 0; + float max_p = probabilities[0]; + for (int i = 1; i < n; i++) { + if (probabilities[i] > max_p) { + max_i = i; + max_p = probabilities[i]; + } + } + return max_i; +} + +int sample_mult(float* probabilities, int n, float coin) { + // sample index from probabilities (they must sum to 1!) + // coin is a random number in [0, 1), usually from random_f32() + float cdf = 0.0f; + for (int i = 0; i < n; i++) { + cdf += probabilities[i]; + if (coin < cdf) { + return i; + } + } + return n - 1; // in case of rounding errors +} + +int compare(const void* a, const void* b) { + ProbIndex* a_ = (ProbIndex*) a; + ProbIndex* b_ = (ProbIndex*) b; + if (a_->prob > b_->prob) return -1; + if (a_->prob < b_->prob) return 1; + return 0; +} + +int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) { + // top-p sampling (or "nucleus sampling") samples from the smallest set of + // tokens that exceed probability topp. This way we never sample tokens that + // have very low probabilities and are less likely to go "off the rails". 
+ // coin is a random number in [0, 1), usually from random_f32() + + int n0 = 0; + // quicksort indices in descending order of probabilities + // values smaller than (1 - topp) / (n - 1) cannot be part of the result + // so for efficiency we crop these out as candidates before sorting + const float cutoff = (1.0f - topp) / (n - 1); + for (int i = 0; i < n; i++) { + if (probabilities[i] >= cutoff) { + probindex[n0].index = i; + probindex[n0].prob = probabilities[i]; + n0++; + } + } + qsort(probindex, n0, sizeof(ProbIndex), compare); + + // truncate the list where cumulative probability exceeds topp + float cumulative_prob = 0.0f; + int last_idx = n0 - 1; // in case of rounding errors consider all elements + for (int i = 0; i < n0; i++) { + cumulative_prob += probindex[i].prob; + if (cumulative_prob > topp) { + last_idx = i; + break; // we've exceeded topp by including last_idx + } + } + + // sample from the truncated list + float r = coin * cumulative_prob; + float cdf = 0.0f; + for (int i = 0; i <= last_idx; i++) { + cdf += probindex[i].prob; + if (r < cdf) { + return probindex[i].index; + } + } + return probindex[last_idx].index; // in case of rounding errors +} + +void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) { + sampler->vocab_size = vocab_size; + sampler->temperature = temperature; + sampler->topp = topp; + sampler->rng_state = rng_seed; + // buffer only used with nucleus sampling; may not need but it's ~small + sampler->probindex = (ProbIndex*)malloc(sampler->vocab_size * sizeof(ProbIndex)); +} + +void free_sampler(Sampler* sampler) { + free(sampler->probindex); +} + +unsigned int random_u32(unsigned long long *state) { + // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A + *state ^= *state >> 12; + *state ^= *state << 25; + *state ^= *state >> 27; + return (*state * 0x2545F4914F6CDD1Dull) >> 32; +} +float random_f32(unsigned long long *state) { // random float32 in [0,1) + return (random_u32(state) >> 8) / 16777216.0f; +} + +int sample(Sampler* sampler, float* logits) { + // sample the token given the logits and some hyperparameters + int next; + if (sampler->temperature == 0.0f) { + // greedy argmax sampling: take the token with the highest probability + next = sample_argmax(logits, sampler->vocab_size); + } else { + // apply the temperature to the logits + for (int q=0; q < sampler->vocab_size; q++) { logits[q] /= sampler->temperature; } + // apply softmax to the logits to get the probabilities for next token + softmax(logits, sampler->vocab_size); + // flip a (float) coin (this is our source of entropy for sampling) + float coin = random_f32(&sampler->rng_state); + // we sample from this distribution to get the next token + if (sampler->topp <= 0 || sampler->topp >= 1) { + // simply sample from the predicted probability distribution + next = sample_mult(logits, sampler->vocab_size, coin); + } else { + // top-p (nucleus) sampling, clamping the least likely tokens to zero + next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin); + } + } + return next; +} + +// ---------------------------------------------------------------------------- +// utilities: time + +long time_in_ms() { + // return time in milliseconds, for benchmarking the model speed + struct timespec time; + clock_gettime(CLOCK_REALTIME, &time); + return time.tv_sec * 1000 + time.tv_nsec / 1000000; +} + +// ---------------------------------------------------------------------------- +// generation loop + +void 
generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) { + + // encode the (string) prompt into tokens sequence + int num_prompt_tokens = 0; + int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS + encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens); + if (num_prompt_tokens < 1) { + fprintf(stderr, "something is wrong, expected at least 1 prompt token\n"); + exit(EXIT_FAILURE); + } + + // start the main loop + long start = 0; // used to time our code, only initialized after first iteration + int next; // will store the next token in the sequence + int token = prompt_tokens[0]; // kick off with the first token in the prompt + int pos = 0; // position in the sequence + while (pos < steps) { + + // forward the transformer to get logits for the next token + float* logits = forward(transformer, token, pos); + + // advance the state machine + if (pos < num_prompt_tokens - 1) { + // if we are still processing the input prompt, force the next prompt token + next = prompt_tokens[pos + 1]; + } else { + // otherwise sample the next token from the logits + next = sample(sampler, logits); + } + pos++; + + // data-dependent terminating condition: the BOS (=1) token delimits sequences + if (next == 1) { break; } + + // print the token as string, decode it with the Tokenizer object + char* piece = decode(tokenizer, token, next); + safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes + fflush(stdout); + token = next; + + // init the timer here because the first iteration can be slower + if (start == 0) { start = time_in_ms(); } + } + printf("\n"); + + // report achieved tok/s (pos-1 because the timer starts after first iteration) + if (pos > 1) { + long end = time_in_ms(); + fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000); + } + + free(prompt_tokens); +} + + +// ---------------------------------------------------------------------------- +// CLI, include only if not testing +#ifndef TESTING + +void error_usage() { + fprintf(stderr, "Usage: run <checkpoint> [options]\n"); + fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n"); + fprintf(stderr, "Options:\n"); + fprintf(stderr, " -t <float> temperature in [0,inf], default 1.0\n"); + fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n"); + fprintf(stderr, " -s <int> random seed, default time(NULL)\n"); + fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n"); + fprintf(stderr, " -i <string> input prompt\n"); + fprintf(stderr, " -z <string> optional path to custom tokenizer\n"); + exit(EXIT_FAILURE); +} + + +int main(int argc, char *argv[]) { + + // default parameters + char *checkpoint_path = NULL; // e.g. out/model.bin + char *tokenizer_path = "tokenizer.bin"; + float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher + float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 
0.9 works well, but slower + int steps = 256; // number of steps to run for + char *prompt = ""; // prompt string + unsigned long long rng_seed = 0; // seed rng with time by default + + // poor man's C argparse so we can override the defaults above from the command line + if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); } + for (int i = 2; i < argc; i+=2) { + // do some basic validation + if (i + 1 >= argc) { error_usage(); } // must have arg after flag + if (argv[i][0] != '-') { error_usage(); } // must start with dash + if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter) + // read in the args + if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); } + else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); } + else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); } + else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); } + else if (argv[i][1] == 'i') { prompt = argv[i + 1]; } + else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; } + else { error_usage(); } + } + + // parameter validation/overrides + if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL); + if (temperature < 0.0) temperature = 0.0; + if (topp < 0.0 || 1.0 < topp) topp = 0.9; + if (steps < 0) steps = 0; + + // build the Transformer via the model .bin file + Transformer transformer; + build_transformer(&transformer, checkpoint_path); + if (steps == 0 || steps > transformer.config.seq_len) steps = transformer.config.seq_len; // override to ~max length + + // build the Tokenizer via the tokenizer .bin file + Tokenizer tokenizer; + build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size); + + // build the Sampler + Sampler sampler; + build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed); + + // run! 
+ generate(&transformer, &tokenizer, &sampler, prompt, steps); + + // memory and file handles cleanup + free_sampler(&sampler); + free_tokenizer(&tokenizer); + free_transformer(&transformer); + return 0; +} + +#endif diff --git a/cpp/src/examples/image_classifier/mnist/mnist_handler.cc b/cpp/src/examples/image_classifier/mnist/mnist_handler.cc index ca07cdd0cc..3fae5748a4 100644 --- a/cpp/src/examples/image_classifier/mnist/mnist_handler.cc +++ b/cpp/src/examples/image_classifier/mnist/mnist_handler.cc @@ -2,15 +2,16 @@ namespace mnist { void MnistHandler::Postprocess( - const torch::Tensor& data, + c10::IValue& data, std::pair&>& idx_to_req_id, std::shared_ptr& response_batch) { + auto data_tensor = data.toTensor(); for (const auto& kv : idx_to_req_id.second) { try { auto response = (*response_batch)[kv.second]; - response->SetResponse(200, "data_tpye", - torchserve::PayloadType::kDATA_TYPE_BYTES, - torch::pickle_save(torch::argmax(data[kv.first]))); + response->SetResponse( + 200, "data_tpye", torchserve::PayloadType::kDATA_TYPE_BYTES, + torch::pickle_save(torch::argmax(data_tensor[kv.first]))); } catch (const std::runtime_error& e) { LOG(ERROR) << "Failed to load tensor for request id:" << kv.second << ", error: " << e.what(); @@ -34,14 +35,14 @@ void MnistHandler::Postprocess( #if defined(__linux__) || defined(__APPLE__) extern "C" { -torchserve::torchscripted::BaseHandler* allocatorMnistHandler() { +torchserve::BaseHandler* allocatorMnistHandler() { return new mnist::MnistHandler(); } -void deleterMnistHandler(torchserve::torchscripted::BaseHandler* p) { +void deleterMnistHandler(torchserve::BaseHandler* p) { if (p != nullptr) { delete static_cast(p); } } } -#endif \ No newline at end of file +#endif diff --git a/cpp/src/examples/image_classifier/mnist/mnist_handler.hh b/cpp/src/examples/image_classifier/mnist/mnist_handler.hh index 4b9d9ff807..f54aba2171 100644 --- a/cpp/src/examples/image_classifier/mnist/mnist_handler.hh +++ b/cpp/src/examples/image_classifier/mnist/mnist_handler.hh @@ -1,10 +1,9 @@ -#ifndef MNIST_HANDLER_HH_ -#define MNIST_HANDLER_HH_ +#pragma once -#include "src/backends/torch_scripted/handler/base_handler.hh" +#include "src/backends/handler/torch_scripted_handler.hh" namespace mnist { -class MnistHandler : public torchserve::torchscripted::BaseHandler { +class MnistHandler : public torchserve::TorchScriptHandler { public: // NOLINTBEGIN(bugprone-exception-escape) MnistHandler() = default; @@ -12,10 +11,9 @@ class MnistHandler : public torchserve::torchscripted::BaseHandler { ~MnistHandler() override = default; void Postprocess( - const torch::Tensor& data, + c10::IValue& data, std::pair&>& idx_to_req_id, std::shared_ptr& response_batch) override; }; } // namespace mnist -#endif // MNIST_HANDLER_HH_ \ No newline at end of file diff --git a/cpp/src/utils/model_archive.cc b/cpp/src/utils/model_archive.cc index ed46d0cad1..b29a099541 100644 --- a/cpp/src/utils/model_archive.cc +++ b/cpp/src/utils/model_archive.cc @@ -1,5 +1,7 @@ #include "src/utils/model_archive.hh" +#include + namespace torchserve { bool Manifest::Initialize(const std::string& manifest_json_file_path) { try { @@ -16,14 +18,10 @@ bool Manifest::Initialize(const std::string& manifest_json_file_path) { } SetValue(model, torchserve::Manifest::kModel_Handler, model_.handler, true); - if (!SetValue(model, torchserve::Manifest::kModel_SerializedFile, - model_.serialized_file, false) && - !SetValue(model, torchserve::Manifest::kModel_ModelFile, - model_.model_file, false)) { - TS_LOGF(ERROR, "Item: {} and 
item : {} not defined in {}", - torchserve::Manifest::kModel_SerializedFile, - torchserve::Manifest::kModel_ModelFile, manifest_json_file_path); - } + SetValue(model, torchserve::Manifest::kModel_SerializedFile, + model_.serialized_file, false); + SetValue(model, torchserve::Manifest::kModel_ModelFile, model_.model_file, + false); SetValue(model, torchserve::Manifest::kModel_ModelName, model_.model_name, false); @@ -47,14 +45,14 @@ bool Manifest::Initialize(const std::string& manifest_json_file_path) { SetValue(val, torchserve::Manifest::kArchiverVersion, archiver_version_, false); SetValue(val, torchserve::Manifest::kRuntimeType, runtime_type_, false); + return true; } catch (const std::invalid_argument& e) { TS_LOGF(ERROR, "Failed to init Manifest from: {}, error: ", manifest_json_file_path, e.what()); - return false; } catch (...) { TS_LOGF(ERROR, "Failed to init Manifest from: {}", manifest_json_file_path); } - return true; + return false; } bool Manifest::SetValue(const folly::dynamic& source, const std::string& key, @@ -70,4 +68,4 @@ bool Manifest::SetValue(const folly::dynamic& source, const std::string& key, } return true; } -} // namespace torchserve \ No newline at end of file +} // namespace torchserve diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 70017424e0..838b250bd9 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -9,11 +9,10 @@ FetchContent_MakeAvailable(googletest) enable_testing() set(TEST_BINARY ${CMAKE_PROJECT_NAME}_test) -set(TS_TEST_SRC_DIR "${torchserve_cpp_SOURCE_DIR}/test/") file(GLOB_RECURSE TEST_SOURCES LIST_DIRECTORIES false *.cc *.hh) add_executable(${TEST_BINARY} ${TEST_SOURCES}) -target_link_libraries(${TEST_BINARY} gtest_main gmock_main ts_backends_torch_scripted ts_backends_protocol ts_utils) +target_link_libraries(${TEST_BINARY} gtest_main gmock_main ts_backends_core ts_backends_protocol ts_utils ${TORCH_LIBRARIES}) include(GoogleTest) -gtest_discover_tests(${TEST_BINARY}) \ No newline at end of file +gtest_discover_tests(${TEST_BINARY}) diff --git a/cpp/test/backends/otf_protocol_and_handler_test.cc b/cpp/test/backends/otf_protocol_and_handler_test.cc index e4be81cc74..cc0d7960ec 100644 --- a/cpp/test/backends/otf_protocol_and_handler_test.cc +++ b/cpp/test/backends/otf_protocol_and_handler_test.cc @@ -1,8 +1,8 @@ #include #include "protocol/mock_socket.hh" +#include "src/backends/core/backend.hh" #include "src/backends/process/model_worker.hh" -#include "src/backends/torch_scripted/torch_scripted_backend.hh" #include "src/utils/metrics/registry.hh" namespace torchserve { @@ -68,7 +68,7 @@ TEST(BackendIntegTest, TestOTFProtocolAndHandler) { ASSERT_EQ(load_model_request->gpu_id, -1); // initialize backend - auto backend = std::make_shared(); + auto backend = std::make_shared(); MetricsRegistry::Initialize("test/resources/metrics/default_config.yaml", MetricsContext::BACKEND); backend->Initialize("test/resources/torchscript_model/mnist/base_handler"); diff --git a/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc b/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc deleted file mode 100644 index b3099d1a2a..0000000000 --- a/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc +++ /dev/null @@ -1,110 +0,0 @@ -#include "src/backends/torch_scripted/torch_scripted_backend.hh" - -#include -#include - -#include -#include - -#include "src/utils/message.hh" -#include "src/utils/metrics/registry.hh" - -namespace torchserve { -class TorchScriptedBackendTest : public ::testing::Test { - 
protected: - void SetUp() override { - backend_ = std::make_shared(); - } - - void LoadPredict( - std::shared_ptr load_model_request, - const std::string& model_dir, - const std::string& inference_input_file_path, - const std::string& inference_request_id_prefix, - int inference_expect_code) { - MetricsRegistry::Initialize("test/resources/metrics/default_config.yaml", - MetricsContext::BACKEND); - backend_->Initialize(model_dir); - auto result = backend_->LoadModel(std::move(load_model_request)); - ASSERT_EQ(result->code, 200); - - std::ifstream input(inference_input_file_path, - std::ios::in | std::ios::binary); - std::vector image((std::istreambuf_iterator(input)), - (std::istreambuf_iterator())); - input.close(); - - auto inference_request_batch = - std::make_shared(); - for (uint8_t i = 0; i < batch_size_; i++) { - torchserve::InferenceRequest inference_request; - inference_request.request_id = - fmt::format("{}_{}", inference_request_id_prefix, i); - inference_request - .headers[torchserve::PayloadType::kHEADER_NAME_DATA_TYPE] = - torchserve::PayloadType::kDATA_TYPE_BYTES; - inference_request - .parameters[torchserve::PayloadType::kPARAMETER_NAME_DATA] = image; - - (*inference_request_batch).emplace_back(inference_request); - } - - auto inference_response_batch = - backend_->GetModelInstance()->Predict(inference_request_batch); - for (const auto& kv : *inference_response_batch) { - ASSERT_EQ(kv.second->code, inference_expect_code); - } - }; - - uint8_t batch_size_ = 2; - std::shared_ptr backend_; -}; - -TEST_F(TorchScriptedBackendTest, TestLoadPredictBaseHandler) { - this->LoadPredict(std::make_shared( - "test/resources/torchscript_model/mnist/mnist_handler", - "mnist_scripted_v2", -1, "", "", 1, false), - "test/resources/torchscript_model/mnist/base_handler", - "test/resources/torchscript_model/mnist/0_png.pt", - "mnist_ts", 200); -} - -TEST_F(TorchScriptedBackendTest, TestLoadPredictMnistHandler) { - this->LoadPredict(std::make_shared( - "test/resources/torchscript_model/mnist/mnist_handler", - "mnist_scripted_v2", -1, "", "", 1, false), - "test/resources/torchscript_model/mnist/mnist_handler", - "test/resources/torchscript_model/mnist/0_png.pt", - "mnist_ts", 200); -} - -TEST_F(TorchScriptedBackendTest, TestBackendInitWrongModelDir) { - auto result = backend_->Initialize("test/resources/torchscript_model/mnist"); - ASSERT_EQ(result, false); -} - -TEST_F(TorchScriptedBackendTest, TestBackendInitWrongHandler) { - auto result = backend_->Initialize( - "test/resources/torchscript_model/mnist/wrong_handler"); - ASSERT_EQ(result, false); -} - -TEST_F(TorchScriptedBackendTest, TestLoadModelFailure) { - backend_->Initialize("test/resources/torchscript_model/mnist/wrong_model"); - auto result = - backend_->LoadModel(std::make_shared( - "test/resources/torchscript_model/mnist/wrong_model", - "mnist_scripted_v2", -1, "", "", 1, false)); - ASSERT_EQ(result->code, 500); -} - -TEST_F(TorchScriptedBackendTest, TestLoadPredictMnistHandlerFailure) { - this->LoadPredict(std::make_shared( - "test/resources/torchscript_model/mnist/mnist_handler", - "mnist_scripted_v2", -1, "", "", 1, false), - "test/resources/torchscript_model/mnist/mnist_handler", - "test/resources/torchscript_model/mnist/0.png", "mnist_ts", - 500); -} - -} // namespace torchserve diff --git a/cpp/test/examples/examples_test.cc b/cpp/test/examples/examples_test.cc new file mode 100644 index 0000000000..f3a1d4b231 --- /dev/null +++ b/cpp/test/examples/examples_test.cc @@ -0,0 +1,10 @@ +#include "test/utils/common.hh" + 
+TEST_F(ModelPredictTest, TestLoadPredictBabyLlamaHandler) { + this->LoadPredict( + std::make_shared( + "test/resources/torchscript_model/babyllama/babyllama_handler", "llm", + -1, "", "", 1, false), + "test/resources/torchscript_model/babyllama/babyllama_handler", + "test/resources/torchscript_model/babyllama/prompt.txt", "llm_ts", 200); +} diff --git a/cpp/test/resources/torchscript_model/babyllama/babyllama_handler/MAR-INF/MANIFEST.json b/cpp/test/resources/torchscript_model/babyllama/babyllama_handler/MAR-INF/MANIFEST.json new file mode 100644 index 0000000000..9ff70f75f6 --- /dev/null +++ b/cpp/test/resources/torchscript_model/babyllama/babyllama_handler/MAR-INF/MANIFEST.json @@ -0,0 +1,10 @@ +{ + "createdOn": "28/07/2020 06:32:08", + "runtime": "LSP", + "model": { + "modelName": "babyllama", + "handler": "libbabyllama_handler:BabyLlamaHandler", + "modelVersion": "2.0" + }, + "archiverVersion": "0.2.0" +} diff --git a/cpp/test/resources/torchscript_model/babyllama/babyllama_handler/config.json b/cpp/test/resources/torchscript_model/babyllama/babyllama_handler/config.json new file mode 100644 index 0000000000..2030358b84 --- /dev/null +++ b/cpp/test/resources/torchscript_model/babyllama/babyllama_handler/config.json @@ -0,0 +1,5 @@ +{ +"checkpoint_path" : "/home/ubuntu/serve/cpp/stories15M.bin", +"tokenizer_path" : "/home/ubuntu/serve/cpp/src/examples/babyllama/tokenizer.bin" +} + diff --git a/cpp/test/resources/torchscript_model/babyllama/babyllama_handler/config.properties b/cpp/test/resources/torchscript_model/babyllama/babyllama_handler/config.properties new file mode 100644 index 0000000000..87f86fc6af --- /dev/null +++ b/cpp/test/resources/torchscript_model/babyllama/babyllama_handler/config.properties @@ -0,0 +1 @@ +default_response_timeout=300000 diff --git a/cpp/test/resources/torchscript_model/babyllama/prompt.txt b/cpp/test/resources/torchscript_model/babyllama/prompt.txt new file mode 100644 index 0000000000..74b56be151 --- /dev/null +++ b/cpp/test/resources/torchscript_model/babyllama/prompt.txt @@ -0,0 +1 @@ +Hello my name is diff --git a/cpp/test/resources/torchscript_model/mnist/base_handler/MAR-INF/MANIFEST.json b/cpp/test/resources/torchscript_model/mnist/base_handler/MAR-INF/MANIFEST.json index bf8123e3a8..6864aa50fd 100644 --- a/cpp/test/resources/torchscript_model/mnist/base_handler/MAR-INF/MANIFEST.json +++ b/cpp/test/resources/torchscript_model/mnist/base_handler/MAR-INF/MANIFEST.json @@ -4,7 +4,7 @@ "model": { "modelName": "mnist_scripted_v2", "serializedFile": "mnist_script.pt", - "handler": "BaseHandler", + "handler": "TorchScriptHandler", "modelVersion": "2.0" }, "archiverVersion": "0.2.0" diff --git a/cpp/test/resources/torchscript_model/mnist/wrong_model/MAR-INF/MANIFEST.json b/cpp/test/resources/torchscript_model/mnist/wrong_model/MAR-INF/MANIFEST.json index bf8123e3a8..6864aa50fd 100644 --- a/cpp/test/resources/torchscript_model/mnist/wrong_model/MAR-INF/MANIFEST.json +++ b/cpp/test/resources/torchscript_model/mnist/wrong_model/MAR-INF/MANIFEST.json @@ -4,7 +4,7 @@ "model": { "modelName": "mnist_scripted_v2", "serializedFile": "mnist_script.pt", - "handler": "BaseHandler", + "handler": "TorchScriptHandler", "modelVersion": "2.0" }, "archiverVersion": "0.2.0" diff --git a/cpp/test/torch_scripted/torch_scripted_test.cc b/cpp/test/torch_scripted/torch_scripted_test.cc new file mode 100644 index 0000000000..cc6806d5c7 --- /dev/null +++ b/cpp/test/torch_scripted/torch_scripted_test.cc @@ -0,0 +1,55 @@ +#include +#include + +#include +#include + +#include 
"src/utils/message.hh" +#include "test/utils/common.hh" + +TEST_F(ModelPredictTest, TestLoadPredictBaseHandler) { + this->LoadPredict(std::make_shared( + "test/resources/torchscript_model/mnist/mnist_handler", + "mnist_scripted_v2", -1, "", "", 1, false), + "test/resources/torchscript_model/mnist/base_handler", + "test/resources/torchscript_model/mnist/0_png.pt", + "mnist_ts", 200); +} + +TEST_F(ModelPredictTest, TestLoadPredictMnistHandler) { + this->LoadPredict(std::make_shared( + "test/resources/torchscript_model/mnist/mnist_handler", + "mnist_scripted_v2", -1, "", "", 1, false), + "test/resources/torchscript_model/mnist/mnist_handler", + "test/resources/torchscript_model/mnist/0_png.pt", + "mnist_ts", 200); +} + +TEST_F(ModelPredictTest, TestBackendInitWrongModelDir) { + auto result = backend_->Initialize("test/resources/torchscript_model/mnist"); + ASSERT_EQ(result, false); +} + +TEST_F(ModelPredictTest, TestBackendInitWrongHandler) { + auto result = backend_->Initialize( + "test/resources/torchscript_model/mnist/wrong_handler"); + ASSERT_EQ(result, false); +} + +TEST_F(ModelPredictTest, TestLoadModelFailure) { + backend_->Initialize("test/resources/torchscript_model/mnist/wrong_model"); + auto result = + backend_->LoadModel(std::make_shared( + "test/resources/torchscript_model/mnist/wrong_model", + "mnist_scripted_v2", -1, "", "", 1, false)); + ASSERT_EQ(result->code, 500); +} + +TEST_F(ModelPredictTest, TestLoadPredictMnistHandlerFailure) { + this->LoadPredict(std::make_shared( + "test/resources/torchscript_model/mnist/mnist_handler", + "mnist_scripted_v2", -1, "", "", 1, false), + "test/resources/torchscript_model/mnist/mnist_handler", + "test/resources/torchscript_model/mnist/0.png", "mnist_ts", + 500); +} diff --git a/cpp/test/utils/common.hh b/cpp/test/utils/common.hh new file mode 100644 index 0000000000..27d548503a --- /dev/null +++ b/cpp/test/utils/common.hh @@ -0,0 +1,54 @@ +#pragma once +#include + +#include "src/backends/core/backend.hh" +#include "src/utils/metrics/registry.hh" + +class ModelPredictTest : public ::testing::Test { + protected: + void SetUp() override { backend_ = std::make_shared(); } + + void LoadPredict( + std::shared_ptr load_model_request, + const std::string& model_dir, + const std::string& inference_input_file_path, + const std::string& inference_request_id_prefix, + int inference_expect_code) { + torchserve::MetricsRegistry::Initialize( + "test/resources/metrics/default_config.yaml", + torchserve::MetricsContext::BACKEND); + backend_->Initialize(model_dir); + auto result = backend_->LoadModel(std::move(load_model_request)); + ASSERT_EQ(result->code, 200); + + std::ifstream input(inference_input_file_path, + std::ios::in | std::ios::binary); + std::vector image((std::istreambuf_iterator(input)), + (std::istreambuf_iterator())); + input.close(); + + auto inference_request_batch = + std::make_shared(); + for (uint8_t i = 0; i < batch_size_; i++) { + torchserve::InferenceRequest inference_request; + inference_request.request_id = + fmt::format("{}_{}", inference_request_id_prefix, i); + inference_request + .headers[torchserve::PayloadType::kHEADER_NAME_DATA_TYPE] = + torchserve::PayloadType::kDATA_TYPE_BYTES; + inference_request + .parameters[torchserve::PayloadType::kPARAMETER_NAME_DATA] = image; + + (*inference_request_batch).emplace_back(inference_request); + } + + auto inference_response_batch = + backend_->GetModelInstance()->Predict(inference_request_batch); + for (const auto& kv : *inference_response_batch) { + ASSERT_EQ(kv.second->code, 
inference_expect_code); + } + }; + + uint8_t batch_size_ = 2; + std::shared_ptr backend_; +}; diff --git a/cpp/test/utils/model_archiver_test.cc b/cpp/test/utils/model_archiver_test.cc index f62df6779d..ea3f5082a2 100644 --- a/cpp/test/utils/model_archiver_test.cc +++ b/cpp/test/utils/model_archiver_test.cc @@ -13,7 +13,7 @@ TEST(ManifestTest, TestInitialize) { ASSERT_EQ(manifest.GetRuntimeType(), "LSP"); ASSERT_EQ(manifest.GetModel().model_name, "mnist_scripted_v2"); ASSERT_EQ(manifest.GetModel().serialized_file, "mnist_script.pt"); - ASSERT_EQ(manifest.GetModel().handler, "BaseHandler"); + ASSERT_EQ(manifest.GetModel().handler, "TorchScriptHandler"); ASSERT_EQ(manifest.GetModel().model_version, "2.0"); } } // namespace torchserve diff --git a/examples/cpp/babyllama/README.md b/examples/cpp/babyllama/README.md new file mode 100644 index 0000000000..cba4df5cd5 --- /dev/null +++ b/examples/cpp/babyllama/README.md @@ -0,0 +1,87 @@ +This example is adapted from https://github.com/karpathy/llama2.c. The handler C++ source code for this example can be found [here](../../../cpp/src/examples/babyllama/). + +### Setup +1. Follow the instructions in [README.md](../../../cpp/README.md) to build the TorchServe C++ backend. + +``` +cd serve/cpp +./build.sh +``` + +2. Download the model and tokenizer using the following commands: + +```bash +cd ~/serve/examples/cpp/babyllama +wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin +wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin +``` + +3. Create a [config.json](config.json) with the paths of the downloaded model and tokenizer: + +```bash +echo '{ +"checkpoint_path" : "/home/ubuntu/serve/examples/cpp/babyllama/stories15M.bin", +"tokenizer_path" : "/home/ubuntu/serve/examples/cpp/babyllama/tokenizer.bin" +}' > config.json +``` + +4. Copy the handler .so file + +While building the C++ backend, the `libbabyllama_handler.so` file is generated in the [babyllama_handler](../../../cpp/test/resources/examples/babyllama/babyllama_handler) folder. + +```bash +cp ../../../cpp/test/resources/examples/babyllama/babyllama_handler/libbabyllama_handler.so ./ +``` + +### Generate MAR file + +Now let's generate the MAR file: + +```bash +torch-model-archiver --model-name llm --version 1.0 --handler libbabyllama_handler:BabyLlamaHandler --runtime LSP --extra-files config.json +``` + +Create a model store directory and move the MAR file into it: + +``` +mkdir model_store +mv llm.mar model_store/ +``` + +### Inference + +Start TorchServe using the following command: + +``` +torchserve --ncs --model-store model_store/ +``` + +Register the model using the following command: + +``` +curl -v -X POST "http://localhost:8081/models?initial_workers=1&url=llm.mar&batch_size=2&max_batch_delay=5000" +``` + +Run inference using the following command: + +``` +curl http://localhost:8080/predictions/llm -T prompt1.txt +``` + +This example supports batching. To run a batched prediction, run the following command: + +``` +curl http://localhost:8080/predictions/llm -T prompt1.txt & curl http://localhost:8080/predictions/llm -T prompt2.txt & +``` + +Sample Response + +``` +Hello my name is Daisy. Daisy is three years old. She loves to play with her toys. +One day, Daisy's mommy said, "Daisy, it's time to go to the store." Daisy was so excited! She ran to the store with her mommy. +At the store, Daisy saw a big, red balloon. She wanted it so badly! She asked her mommy, "Can I have the balloon, please?" +Mommy said, "No, Daisy. We don't have enough money for that balloon." 
+Daisy was sad. She wanted the balloon so much. She started to cry. +Mommy said, "Daisy, don't cry. We can get the balloon. We can buy it and take it home." +Daisy smiled. She was so happy. She hugged her mommy and said, "Thank you, mommy!" +``` diff --git a/examples/cpp/babyllama/prompt1.txt b/examples/cpp/babyllama/prompt1.txt new file mode 100644 index 0000000000..baa5a1abbf --- /dev/null +++ b/examples/cpp/babyllama/prompt1.txt @@ -0,0 +1 @@ +Hello my name is Dan diff --git a/examples/cpp/babyllama/prompt2.txt b/examples/cpp/babyllama/prompt2.txt new file mode 100644 index 0000000000..99568648e9 --- /dev/null +++ b/examples/cpp/babyllama/prompt2.txt @@ -0,0 +1 @@ +Hello my name is Daisy diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index 7ca49ddaf0..865e6400fe 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -1164,3 +1164,9 @@ compilable nightlies torchexportaotcompile autotune +babyllama +libbabyllama +BabyLLama +BabyLlamaHandler +CMakeLists +TorchScriptHandler