diff --git a/include/tl2cgen/detail/compiler/codegen/codegen.h b/include/tl2cgen/detail/compiler/codegen/codegen.h index 7fb62b5..4c89c28 100644 --- a/include/tl2cgen/detail/compiler/codegen/codegen.h +++ b/include/tl2cgen/detail/compiler/codegen/codegen.h @@ -7,16 +7,27 @@ #ifndef TL2CGEN_DETAIL_COMPILER_CODEGEN_CODEGEN_H_ #define TL2CGEN_DETAIL_COMPILER_CODEGEN_CODEGEN_H_ +#include #include #include #include #include -namespace tl2cgen::compiler::detail { +/* Forward declarations */ +namespace treelite { -namespace ast { +class Model; + +} // namespace treelite + +namespace tl2cgen::compiler { + +struct CompilerParam; + +} // namespace tl2cgen::compiler + +namespace tl2cgen::compiler::detail::ast { -// Forward declarations class ASTNode; class MainNode; class FunctionNode; @@ -26,13 +37,16 @@ class TranslationUnitNode; class QuantizerNode; class ModelMeta; -} // namespace ast +} // namespace tl2cgen::compiler::detail::ast -namespace codegen { +namespace tl2cgen::compiler::detail::codegen { class CodeCollection; // forward declaration void GenerateCodeFromAST(ast::ASTNode const* node, CodeCollection& gencode); +void WriteCodeToDisk(std::filesystem::path const& dirpath, CodeCollection const& collection); +void WriteBuildRecipeToDisk(std::filesystem::path const& dirpath, + std::string const& native_lib_name, CodeCollection const& collection); // Codegen implementation for each AST node type void HandleMainNode(ast::MainNode const* node, CodeCollection& gencode); @@ -70,6 +84,9 @@ class SourceFile { void ChangeIndent(int n_tabs_delta); // Add or remove indent void PushFragment(std::string content); friend std::ostream& operator<<(std::ostream&, CodeCollection const&); + friend void WriteCodeToDisk(std::filesystem::path const& dirpath, CodeCollection const&); + friend void WriteBuildRecipeToDisk( + std::filesystem::path const&, std::string const&, CodeCollection const&); friend class CodeCollection; }; @@ -88,11 +105,12 @@ class CodeCollection { void PushFragment(std::string content); friend std::ostream& operator<<(std::ostream&, CodeCollection const&); + friend void WriteCodeToDisk(std::filesystem::path const&, CodeCollection const&); + friend void WriteBuildRecipeToDisk( + std::filesystem::path const&, std::string const&, CodeCollection const&); }; std::ostream& operator<<(std::ostream& os, CodeCollection const& collection); -} // namespace codegen - -} // namespace tl2cgen::compiler::detail +} // namespace tl2cgen::compiler::detail::codegen #endif // TL2CGEN_DETAIL_COMPILER_CODEGEN_CODEGEN_H_ diff --git a/include/tl2cgen/detail/filesystem.h b/include/tl2cgen/detail/filesystem.h index ac4ad35..a394b03 100644 --- a/include/tl2cgen/detail/filesystem.h +++ b/include/tl2cgen/detail/filesystem.h @@ -19,21 +19,6 @@ namespace tl2cgen::detail::filesystem { */ void CreateDirectoryIfNotExist(std::filesystem::path const& dirpath); -/*! - * \brief Write a sequence of strings to a text file, with newline character (\n) inserted between - * strings. This function is suitable for creating multi-line text files. - * \param path Path to text file - * \param content A sequence of strings to be written. - */ -void WriteToFile(std::filesystem::path const& path, std::string const& content); - -/*! - * \brief Write a sequence of bytes to a text file - * \param path Path to text file - * \param content A sequence of bytes to be written. - */ -void WriteToFile(std::filesystem::path const& path, std::vector const& content); - } // namespace tl2cgen::detail::filesystem #endif // TL2CGEN_DETAIL_FILESYSTEM_H_ diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index cd07f47..40a74fd 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -87,7 +87,6 @@ int TL2cgenGenerateCode( std::filesystem::path dirpath_ = std::filesystem::weakly_canonical(std::filesystem::u8path(std::string(dirpath))); - detail::filesystem::CreateDirectoryIfNotExist(dirpath_); /* Compile model */ auto param = compiler::CompilerParam::ParseFromJSON(compiler_params_json_str); diff --git a/src/compiler/codegen/codegen.cc b/src/compiler/codegen/codegen.cc index a89543e..19dd171 100644 --- a/src/compiler/codegen/codegen.cc +++ b/src/compiler/codegen/codegen.cc @@ -5,11 +5,15 @@ * \author Hyunsu Cho */ +#include +#include #include #include #include +#include #include +#include #include #include #include @@ -56,6 +60,49 @@ void GenerateCodeFromAST(ast::ASTNode const* node, CodeCollection& gencode) { } } +void WriteCodeToDisk(std::filesystem::path const& dirpath, CodeCollection const& collection) { + namespace fs = tl2cgen::detail::filesystem; + for (auto const& [file_name, source_file] : collection.sources_) { + std::ofstream of(dirpath / file_name); + for (auto const& fragment : source_file.fragments_) { + of << IndentMultiLineString(fragment.content_, fragment.indent_) << "\n"; + } + of << "\n"; + } +} + +void WriteBuildRecipeToDisk(std::filesystem::path const& dirpath, + std::string const& native_lib_name, CodeCollection const& collection) { + std::ofstream ofs(dirpath / "recipe.json"); + rapidjson::OStreamWrapper ofs_wrapped(ofs); + rapidjson::PrettyWriter writer(ofs_wrapped); + writer.SetFormatOptions(rapidjson::PrettyFormatOptions::kFormatSingleLineArray); + + writer.StartObject(); + writer.Key("target"); + writer.String(native_lib_name); + writer.Key("sources"); + writer.StartArray(); + for (auto const& [file_name, source_file] : collection.sources_) { + if (file_name.compare(file_name.length() - 2, 2, ".c") == 0) { + std::size_t line_count = 0; + for (auto const& fragment : source_file.fragments_) { + line_count += std::count(fragment.content_.begin(), fragment.content_.end(), '\n'); + } + writer.StartObject(); + writer.Key("name"); + std::string name = file_name.substr(0, file_name.length() - 2); + writer.String(name); + writer.Key("length"); + writer.Uint64(line_count); + writer.EndObject(); + } + } + writer.EndArray(); + writer.EndObject(); + ofs << "\n"; // Add newline at the end, for convention's sake +} + std::string GetThresholdCType(ast::ASTNode const* node) { return GetThresholdCType(*node->meta_); } diff --git a/src/compiler/codegen/postprocessor.cc b/src/compiler/codegen/postprocessor.cc index d99aaf5..01d5b56 100644 --- a/src/compiler/codegen/postprocessor.cc +++ b/src/compiler/codegen/postprocessor.cc @@ -192,10 +192,10 @@ static void postprocess_impl({threshold_type}* target_result, int num_class) {{ for (int k = 0; k < num_class; ++k) {{ t = {exp}(target_result[k] - max_margin); norm_const += t; - pred[k] = t; + target_result[k] = t; }} for (int k = 0; k < num_class; ++k) {{ - pred[k] /= ({threshold_type})norm_const; + target_result[k] /= ({threshold_type})norm_const; }} }} @@ -207,7 +207,7 @@ void postprocess({threshold_type}* result) {{ auto const max_num_class = *std::max_element(model_meta.num_class_.begin(), model_meta.num_class_.end()); for (std::int32_t target_id = 0; target_id < model_meta.num_target_; ++target_id) { - fmt::print(oss, " postprocess_impl(result[{offset}], {num_class});\n", + fmt::print(oss, " postprocess_impl(&result[{offset}], {num_class});\n", "offset"_a = target_id * max_num_class, "num_class"_a = model_meta.num_class_[target_id]); } oss << "}\n"; @@ -239,7 +239,7 @@ void postprocess({threshold_type}* result) {{ auto const max_num_class = *std::max_element(model_meta.num_class_.begin(), model_meta.num_class_.end()); for (std::int32_t target_id = 0; target_id < model_meta.num_target_; ++target_id) { - fmt::print(oss, " postprocess_impl(result[{offset}], {num_class});\n", + fmt::print(oss, " postprocess_impl(&result[{offset}], {num_class});\n", "offset"_a = target_id * max_num_class, "num_class"_a = model_meta.num_class_[target_id]); } oss << "}\n"; diff --git a/src/compiler/compiler.cc b/src/compiler/compiler.cc index 87a8979..6a4c0b7 100644 --- a/src/compiler/compiler.cc +++ b/src/compiler/compiler.cc @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -51,11 +52,15 @@ namespace tl2cgen::compiler { void CompileModel(treelite::Model const& model, CompilerParam const& param, std::filesystem::path const& dirpath) { + tl2cgen::detail::filesystem::CreateDirectoryIfNotExist(dirpath); auto builder = LowerToAST(model, param); /* Generate C code */ detail::codegen::CodeCollection gencode; detail::codegen::GenerateCodeFromAST(builder.GetRootNode(), gencode); - TL2CGEN_LOG(INFO) << "\n" << gencode; + // Write C code to disk + detail::codegen::WriteCodeToDisk(dirpath, gencode); + // Write recipe.json + detail::codegen::WriteBuildRecipeToDisk(dirpath, param.native_lib_name, gencode); } std::string DumpAST(treelite::Model const& model, CompilerParam const& param) { diff --git a/src/compiler/postprocessor.cc b/src/compiler/postprocessor.cc deleted file mode 100644 index ea776a9..0000000 --- a/src/compiler/postprocessor.cc +++ /dev/null @@ -1,89 +0,0 @@ -/*! - * Copyright (c) 2024 by Contributors - * \file postprocessor.cc - * \brief Library of transform functions to convert margins into predictions - * \author Hyunsu Cho - */ - -#include -#include -#include -#include - -#include -#include - -#define TL2CGEN_POSTPROCESSOR_FUNC(name) \ - { #name, &(name) } - -namespace { - -using Model = treelite::Model; -using PostprocessorGenerator = std::string (*)(Model const&); - -/* Boilerplate */ -#define TL2CGEN_POSTPROCESSOR_REGISTRY_DEFAULT_TEMPLATE(FUNC_NAME) \ - std::string FUNC_NAME(const treelite::Model& model) { \ - return tl2cgen::compiler::detail::templates::postprocessor::FUNC_NAME(model); \ - } - -/* - * See https://treelite.readthedocs.io/en/latest/knobs/postprocessor.html for the description of - * each postprocessor function. - */ - -TL2CGEN_POSTPROCESSOR_REGISTRY_DEFAULT_TEMPLATE(identity) -TL2CGEN_POSTPROCESSOR_REGISTRY_DEFAULT_TEMPLATE(signed_square) -TL2CGEN_POSTPROCESSOR_REGISTRY_DEFAULT_TEMPLATE(hinge) -TL2CGEN_POSTPROCESSOR_REGISTRY_DEFAULT_TEMPLATE(sigmoid) -TL2CGEN_POSTPROCESSOR_REGISTRY_DEFAULT_TEMPLATE(exponential) -TL2CGEN_POSTPROCESSOR_REGISTRY_DEFAULT_TEMPLATE(exponential_standard_ratio) -TL2CGEN_POSTPROCESSOR_REGISTRY_DEFAULT_TEMPLATE(logarithm_one_plus_exp) -TL2CGEN_POSTPROCESSOR_REGISTRY_DEFAULT_TEMPLATE(identity_multiclass) -TL2CGEN_POSTPROCESSOR_REGISTRY_DEFAULT_TEMPLATE(softmax) -TL2CGEN_POSTPROCESSOR_REGISTRY_DEFAULT_TEMPLATE(multiclass_ova) - -std::unordered_map const const postprocessor_db = { - TL2CGEN_POSTPROCESSOR_FUNC(identity), TL2CGEN_POSTPROCESSOR_FUNC(signed_square), - TL2CGEN_POSTPROCESSOR_FUNC(hinge), TL2CGEN_POSTPROCESSOR_FUNC(sigmoid), - TL2CGEN_POSTPROCESSOR_FUNC(exponential), TL2CGEN_POSTPROCESSOR_FUNC(exponential_standard_ratio), - TL2CGEN_POSTPROCESSOR_FUNC(logarithm_one_plus_exp)}; - -// Postprocessor functions for *multi-class classifiers* -std::unordered_map const const postprocessor_multiclass_db - = {TL2CGEN_POSTPROCESSOR_FUNC(identity_multiclass), TL2CGEN_POSTPROCESSOR_FUNC(softmax), - TL2CGEN_POSTPROCESSOR_FUNC(multiclass_ova)}; - -} // anonymous namespace - -std::string tl2cgen::compiler::detail::PostProcessorFunction(Model const& model) { - TL2CGEN_CHECK_EQ(model.num_target, 1) << "TL2cgen does not yet support multi-target models"; - auto const num_class = model.num_class[0]; - if (num_class > 1) { // multi-class classification - auto it = postprocessor_multiclass_db.find(model.postprocessor); - if (it == postprocessor_multiclass_db.end()) { - std::ostringstream oss; - for (auto const& e : postprocessor_multiclass_db) { - oss << "'" << e.first << "', "; - } - TL2CGEN_LOG(FATAL) << "Invalid argument given for `postprocessor` parameter. " - << "For multi-class classification, you should set " - << "`postprocessor` to one of the following: " - << "{ " << oss.str() << " }"; - } - return (it->second)(model); - } else { - auto it = postprocessor_db.find(model.postprocessor); - if (it == postprocessor_db.end()) { - std::ostringstream oss; - for (auto const& e : postprocessor_db) { - oss << "'" << e.first << "', "; - } - TL2CGEN_LOG(FATAL) << "Invalid argument given for `postprocessor` parameter. " - << "For any task that is NOT multi-class classification, you " - << "should set `postprocessor` to one of the following: " - << "{ " << oss.str() << " }"; - } - return (it->second)(model); - } -} diff --git a/src/filesystem.cc b/src/filesystem.cc index 4f17495..b9f275e 100644 --- a/src/filesystem.cc +++ b/src/filesystem.cc @@ -24,14 +24,4 @@ void CreateDirectoryIfNotExist(std::filesystem::path const& dirpath) { } } -void WriteToFile(std::filesystem::path const& path, std::string const& content) { - std::ofstream of(path); - of << content; -} - -void WriteToFile(std::filesystem::path const& path, std::vector const& content) { - std::ofstream of(path, std::ios::out | std::ios::binary); - of.write(content.data(), content.size()); -} - } // namespace tl2cgen::detail::filesystem