Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adopt Treelite 4.1 + rewrite compiler #20

Merged
merged 18 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ extend-ignore =
F401
# redefinition of unused
F811
extend-ignore =
# E203 whitespace before ':'
# https://github.com/psf/black/issues/315
E203
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,24 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
args: ["--maxkb=4000"]
- repo: https://github.com/psf/black
rev: 23.1.0
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]
- repo: https://github.com/PyCQA/flake8
rev: 5.0.4
rev: 7.0.0
hooks:
- id: flake8
args: [--config=.flake8]
Expand Down Expand Up @@ -50,12 +50,12 @@ repos:
language: python
args: [
"--linelength=100", "--recursive",
"--filter=-build/c++11,-build/include,-build/namespaces_literals,+build/include_what_you_use,+build/include_order",
"--filter=-build/c++11,-build/include,-build/namespaces_literals,-runtime/references,+build/include_what_you_use,-build/include_order",
"--root=include"]
additional_dependencies: [cpplint]
types_or: [c++]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.2.0
rev: v1.8.0
hooks:
- id: mypy
additional_dependencies: [types-setuptools]
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ endif ()
option(TEST_COVERAGE "C++ test coverage" OFF)
option(USE_OPENMP "Use OpenMP" ON)
option(BUILD_DOXYGEN "Build documentation for C/C++ functions using Doxygen." OFF)
option(BUILD_CPP_TESTS "Build C++ tests" OFF)
option(BUILD_CPP_TEST "Build C++ tests" OFF)
option(HIDE_CXX_SYMBOLS "Hide all C++ symbols. Useful when building Pip package" OFF)
option(DETECT_CONDA_ENV "Enable detection of Conda environment for dependencies" ON)
option(BUILD_JVM_RUNTIME "Build TL2cgen runtime for JVM" OFF)
Expand Down Expand Up @@ -63,7 +63,7 @@ add_subdirectory(src)
if (BUILD_JVM_RUNTIME)
add_subdirectory(java_runtime)
endif ()
if (BUILD_CPP_TESTS)
if (BUILD_CPP_TEST)
enable_testing()
add_subdirectory(tests/cpp)
endif ()
Expand Down
39 changes: 31 additions & 8 deletions cmake/ExternalLibs.cmake
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
include(FetchContent)

# Treelite
find_package(Treelite 3.4.0)
find_package(Treelite 4.1.0)
if (Treelite_FOUND)
set(TREELITE_FROM_SYSTEM_ROOT TRUE)
set(TREELITE_LIB treelite::treelite)
Expand All @@ -10,11 +10,11 @@ else ()
FetchContent_Declare(
treelite
GIT_REPOSITORY https://github.com/dmlc/treelite.git
GIT_TAG 3.9.0
GIT_TAG 4.1.0
)
set(Treelite_BUILD_STATIC_LIBS ON)
FetchContent_MakeAvailable(treelite)
set_target_properties(treelite treelite_runtime PROPERTIES EXCLUDE_FROM_ALL TRUE)
set_target_properties(treelite PROPERTIES EXCLUDE_FROM_ALL TRUE)
target_include_directories(treelite_static PUBLIC
$<BUILD_INTERFACE:${treelite_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${treelite_BINARY_DIR}/include>
Expand All @@ -34,7 +34,7 @@ else ()
FetchContent_Declare(
fmtlib
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 9.1.0
GIT_TAG 10.2.1
)
FetchContent_MakeAvailable(fmtlib)
set_target_properties(fmt PROPERTIES EXCLUDE_FROM_ALL TRUE)
Expand All @@ -44,6 +44,7 @@ endif ()
# RapidJSON (header-only library)
if (NOT TARGET rapidjson)
add_library(rapidjson INTERFACE)
target_compile_definitions(rapidjson INTERFACE -DRAPIDJSON_HAS_STDSTRING=1)
find_package(RapidJSON)
if (RapidJSON_FOUND)
target_include_directories(rapidjson INTERFACE ${RAPIDJSON_INCLUDE_DIRS})
Expand All @@ -61,10 +62,28 @@ if (NOT TARGET rapidjson)
add_library(RapidJSON::rapidjson ALIAS rapidjson)
endif ()

# mdspan (header-only library)
message(STATUS "Fetching mdspan...")
set(MDSPAN_CXX_STANDARD 17 CACHE STRING "")
FetchContent_Declare(
mdspan
GIT_REPOSITORY https://github.com/kokkos/mdspan.git
GIT_TAG mdspan-0.6.0
)
FetchContent_GetProperties(mdspan)
if(NOT mdspan_POPULATED)
FetchContent_Populate(mdspan)
add_subdirectory(${mdspan_SOURCE_DIR} ${mdspan_BINARY_DIR} EXCLUDE_FROM_ALL)
message(STATUS "mdspan was downloaded at ${mdspan_SOURCE_DIR}.")
endif()
if(MSVC) # workaround for MSVC 19.x: https://github.com/kokkos/mdspan/issues/276
target_compile_options(mdspan INTERFACE "/permissive-")
endif()

# Google C++ tests
if (BUILD_CPP_TESTS)
find_package(GTest 1.11.0 CONFIG)
if (NOT GTEST_FOUND)
if (BUILD_CPP_TEST)
find_package(GTest 1.11.0)
if (NOT GTest_FOUND)
message(STATUS "Did not find Google Test in the system root. Fetching Google Test now...")
set(gtest_force_shared_crt OFF CACHE BOOL "" FORCE)
FetchContent_Declare(
Expand All @@ -73,7 +92,11 @@ if (BUILD_CPP_TESTS)
GIT_TAG release-1.11.0
)
FetchContent_MakeAvailable(googletest)
add_library(GTest::GTest ALIAS gtest)
add_library(GTest::gtest ALIAS gtest)
add_library(GTest::gmock ALIAS gmock)
if(IS_DIRECTORY "${googletest_SOURCE_DIR}")
# Do not install gtest
set_property(DIRECTORY ${googletest_SOURCE_DIR} PROPERTY EXCLUDE_FROM_ALL YES)
endif()
endif ()
endif ()
2 changes: 2 additions & 0 deletions dev/change_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def update_pypkg(
is_dev: bool,
rc_ver: Optional[int] = None,
) -> None:
# pylint: disable=too-many-arguments
"""Change version in the Python package"""
version = f"{major}.{minor}.{patch}"
if is_rc:
Expand Down Expand Up @@ -66,6 +67,7 @@ def update_java_pkg(
is_dev: bool,
rc_ver: Optional[int] = None,
) -> None:
# pylint: disable=too-many-arguments
"""Change version in the Java package"""
version = f"{major}.{minor}.{patch}"
if is_rc:
Expand Down
2 changes: 1 addition & 1 deletion include/tl2cgen/annotator.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class BranchAnnotator {
* \code
* Annotator annotator
* annotator.Load(fi); // load from a stream
* std::vector<std::vector<size_t>> annot = annotator.Get();
* std::vector<std::vector<std::uint64_t>> annot = annotator.Get();
* // access the frequency count for a specific node in a tree
* TL2CGEN_LOG(INFO) << "Tree " << tree_id << ", Node " << node_id << ": "
* << annot[tree_id][node_id];
Expand Down
Loading
Loading