Skip to content

Commit

Permalink
CPU/GPU interop with RandomForest (#6175)
Browse files Browse the repository at this point in the history
First version of CPU/GPU interop for RandomForest. Note: this feature requires the latest version of Treelite.

Authors:
  - Philip Hyunsu Cho (https://github.com/hcho3)
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - William Hicks (https://github.com/wphicks)

URL: #6175
  • Loading branch information
hcho3 authored Feb 7, 2025
1 parent f67b426 commit f60b5f0
Show file tree
Hide file tree
Showing 8 changed files with 677 additions and 14 deletions.
35 changes: 30 additions & 5 deletions python/cuml/cuml/ensemble/randomforest_common.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,7 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import threading
import treelite.sklearn
from cuml.internals.safe_imports import gpu_only_import
from cuml.internals.api_decorators import device_interop_preparation
from cuml.internals.global_settings import GlobalSettings

cp = gpu_only_import('cupy')
import math
import warnings
Expand All @@ -24,7 +29,7 @@ np = cpu_only_import('numpy')
from cuml import ForestInference
from cuml.fil.fil import TreeliteModel
from pylibraft.common.handle import Handle
from cuml.internals.base import Base
from cuml.internals.base import UniversalBase
from cuml.internals.array import CumlArray
from cuml.common.exceptions import NotFittedError
import cuml.internals
Expand All @@ -39,7 +44,7 @@ from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.prims.label.classlabels import make_monotonic, check_labels


class BaseRandomForestModel(Base):
class BaseRandomForestModel(UniversalBase):
_param_names = ['n_estimators', 'max_depth', 'handle',
'max_features', 'n_bins',
'split_criterion', 'min_samples_leaf',
Expand Down Expand Up @@ -67,6 +72,7 @@ class BaseRandomForestModel(Base):

classes_ = CumlArrayDescriptor()

@device_interop_preparation
def __init__(self, *, split_criterion, n_streams=4, n_estimators=100,
max_depth=16, handle=None, max_features='sqrt', n_bins=128,
bootstrap=True,
Expand All @@ -88,7 +94,7 @@ class BaseRandomForestModel(Base):
"class_weight": class_weight}

for key, vals in sklearn_params.items():
if vals:
if vals and not GlobalSettings().accelerator_active:
raise TypeError(
" The Scikit-learn variable ", key,
" is not supported in cuML,"
Expand All @@ -97,7 +103,7 @@ class BaseRandomForestModel(Base):
"api.html#random-forest) for more information")

for key in kwargs.keys():
if key not in self._param_names:
if key not in self._param_names and not GlobalSettings().accelerator_active:
raise TypeError(
" The variable ", key,
" is not supported in cuML,"
Expand Down Expand Up @@ -154,6 +160,7 @@ class BaseRandomForestModel(Base):
self.model_pbuf_bytes = bytearray()
self.treelite_handle = None
self.treelite_serialized_model = None
self._cpu_model_class_lock = threading.RLock()

def _get_max_feat_val(self) -> float:
if isinstance(self.max_features, int):
Expand Down Expand Up @@ -268,6 +275,24 @@ class BaseRandomForestModel(Base):
self.treelite_handle = <uintptr_t> tl_handle
return self.treelite_handle

def cpu_to_gpu(self):
# Populate this estimator's Treelite-based state from the CPU model held
# in self._cpu_model, then delegate to the parent class's cpu_to_gpu().
#
# Step 1: convert the CPU (scikit-learn) model into a Treelite model.
tl_model = treelite.sklearn.import_model(self._cpu_model)
# Step 2: round-trip through serialized bytes to get a cuML TreeliteModel.
# NOTE(review): the result is stored on self rather than in a local,
# presumably so the object backing `.handle` stays alive — confirm.
self._temp = TreeliteModel.from_treelite_bytes(tl_model.serialize_bytes())
# Step 3: cache the serialized form of that handle on the estimator.
self.treelite_serialized_model = treelite_serialize(self._temp.handle)
# Step 4: refresh self.treelite_handle.
# NOTE(review): presumably rebuilt from treelite_serialized_model — verify
# against _obtain_treelite_handle.
self._obtain_treelite_handle()
# Models imported this way are always marked float64 here.
# NOTE(review): assumes the Treelite sklearn importer yields float64
# thresholds/leaves — confirm.
self.dtype = np.float64
self.update_labels = False
super().cpu_to_gpu()

def gpu_to_cpu(self):
    """Export the current forest into a CPU model stored on ``self``.

    The forest is wrapped as a cuML ``TreeliteModel`` (without taking
    ownership of the handle), serialized to bytes, reloaded as a
    ``treelite.Model`` and finally exported through
    ``treelite.sklearn.export_model``. The result is assigned to
    ``self._cpu_model``; nothing is returned.
    """
    # NOTE(review): presumably guarantees self.treelite_handle is populated
    # before it is read below — confirm against _obtain_treelite_handle.
    self._obtain_treelite_handle()

    # Borrow the handle only: the estimator remains its owner.
    borrowed_model = TreeliteModel.from_treelite_model_handle(
        self.treelite_handle, take_handle_ownership=False
    )
    serialized_forest = borrowed_model.to_treelite_bytes()

    self._cpu_model = treelite.sklearn.export_model(
        treelite.Model.deserialize_bytes(serialized_forest)
    )

@cuml.internals.api_base_return_generic(set_output_type=True,
set_n_features_in=True,
get_output_type=False)
Expand Down
71 changes: 70 additions & 1 deletion python/cuml/cuml/ensemble/randomforestclassifier.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
# limitations under the License.
#


# distutils: language = c++
import sys
import threading

from cuml.internals.api_decorators import device_interop_preparation
from cuml.internals.api_decorators import enable_device_interop
from cuml.internals.safe_imports import (
cpu_only_import,
gpu_only_import,
Expand All @@ -29,7 +35,9 @@ rmm = gpu_only_import('rmm')

from cuml.internals.array import CumlArray
from cuml.internals.mixins import ClassifierMixin
from cuml.internals.global_settings import GlobalSettings
import cuml.internals
from cuml.internals import logger
from cuml.common.doc_utils import generate_docstring
from cuml.common.doc_utils import insert_into_docstring
from cuml.common import input_to_cuml_array
Expand Down Expand Up @@ -248,6 +256,22 @@ class RandomForestClassifier(BaseRandomForestModel,
<https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html>`_.
"""

# Dotted import path of the CPU estimator that corresponds to this class.
# NOTE(review): presumably resolved by the device-interop base class to
# build self._cpu_model — confirm.
_cpu_estimator_import_path = 'sklearn.ensemble.RandomForestClassifier'

# Translation table for hyperparameters, keyed by parameter name. A nested
# dict maps specific incoming values to replacement values.
# NOTE(review): the string "NotImplemented" looks like a marker for
# parameters/values without a cuML equivalent — confirm against the parent
# class's _hyperparam_translator.
_hyperparam_interop_translator = {
# Any value of `criterion` is marked "NotImplemented".
"criterion": "NotImplemented",
# Only oob_score=True is marked "NotImplemented".
"oob_score": {
True: "NotImplemented",
},
# max_depth=None is replaced by 16, the max_depth default of
# BaseRandomForestModel.__init__.
"max_depth": {
None: 16,
},
# max_samples=None is replaced by 1.0.
"max_samples": {
None: 1.0,
},
}

@device_interop_preparation
def __init__(self, *, split_criterion=0, handle=None, verbose=False,
output_type=None,
**kwargs):
Expand Down Expand Up @@ -292,6 +316,10 @@ class RandomForestClassifier(BaseRandomForestModel,
state["treelite_handle"] = None
state["split_criterion"] = self.split_criterion
state["handle"] = self.handle

if "_cpu_model_class_lock" in state:
del state["_cpu_model_class_lock"]

return state

def __setstate__(self, state):
Expand All @@ -314,6 +342,7 @@ class RandomForestClassifier(BaseRandomForestModel,

self.treelite_serialized_model = state["treelite_serialized_model"]
self.__dict__.update(state)
self._cpu_model_class_lock = threading.RLock()

def __del__(self):
self._reset_forest_data()
Expand All @@ -338,6 +367,9 @@ class RandomForestClassifier(BaseRandomForestModel,
self.treelite_serialized_model = None
self.n_cols = None

def get_attr_names(self):
    """Return an empty list of attribute names.

    NOTE(review): presumably consumed by the device-interop base class to
    decide which attributes to carry between CPU and GPU models — confirm.
    """
    return list()

def convert_to_treelite_model(self):
"""
Converts the cuML RF model to a Treelite model
Expand Down Expand Up @@ -418,6 +450,7 @@ class RandomForestClassifier(BaseRandomForestModel,
@cuml.internals.api_base_return_any(set_output_type=False,
set_output_dtype=True,
set_n_features_in=False)
@enable_device_interop
def fit(self, X, y, convert_dtype=True):
"""
Perform Random Forest Classification on the input data
Expand All @@ -429,7 +462,6 @@ class RandomForestClassifier(BaseRandomForestModel,
y to be of dtype int32. This will increase memory used for
the method.
"""

X_m, y_m, max_feature_val = self._dataset_setup_for_fit(X, y,
convert_dtype)
# Track the labels to see if update is necessary
Expand Down Expand Up @@ -556,6 +588,7 @@ class RandomForestClassifier(BaseRandomForestModel,
@insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')],
return_values=[('dense', '(n_samples, 1)')])
@cuml.internals.api_base_return_array(get_output_dtype=True)
@enable_device_interop
def predict(self, X, predict_model="GPU", threshold=0.5,
algo='auto', convert_dtype=True,
fil_sparse_format='auto') -> CumlArray:
Expand Down Expand Up @@ -828,3 +861,39 @@ class RandomForestClassifier(BaseRandomForestModel,
if self.dtype == np.float64:
return get_rf_json(rf_forest64).decode('utf-8')
return get_rf_json(rf_forest).decode('utf-8')

def cpu_to_gpu(self):
    """Run the parent ``cpu_to_gpu``, unmasking the CPU model's class if needed.

    When the accelerator is active, ``self._cpu_model.__class__`` is
    temporarily set to the ``RandomForestClassifier`` found on
    ``sys.modules['sklearn.ensemble']`` for the duration of the parent call,
    and restored afterwards even if that call raises. The swap is done while
    holding ``self._cpu_model_class_lock``.
    """
    if not GlobalSettings().accelerator_active:
        super().cpu_to_gpu()
        return

    # treelite does an internal isinstance check to detect an sklearn RF,
    # which proxymodule interferes with. Work around that only for the
    # duration of the treelite check and put the original class back after.
    with self._cpu_model_class_lock:
        proxied_class = self._cpu_model.__class__
        native_class = sys.modules['sklearn.ensemble'].RandomForestClassifier
        self._cpu_model.__class__ = native_class
        try:
            super().cpu_to_gpu()
        finally:
            self._cpu_model.__class__ = proxied_class

@classmethod
def _hyperparam_translator(cls, **kwargs):
    """Adjust translated hyperparameters for ``RandomForestClassifier``.

    Delegates to the parent class's ``_hyperparam_translator`` first, then
    applies two classifier-specific adjustments:

    * an integer ``max_samples`` is replaced by ``1.0`` and a warning is
      logged;
    * if ``random_state`` is present in the keyword arguments,
      ``n_streams`` is forced to ``1``.

    Parameters
    ----------
    **kwargs
        Hyperparameters to translate.

    Returns
    -------
    tuple
        ``(kwargs, gpuaccel)``: the adjusted keyword arguments and the
        second value returned by the parent translator, passed through
        unchanged.
    """
    kwargs, gpuaccel = super(RandomForestClassifier, cls)._hyperparam_translator(**kwargs)

    if "max_samples" in kwargs and isinstance(kwargs["max_samples"], int):
        # Adjacent string literals are concatenated verbatim, so the
        # separating space must be part of the first literal.
        logger.warn(
            f"Integer value of max_samples={kwargs['max_samples']} "
            "not supported, changed to 1.0."
        )
        kwargs["max_samples"] = 1.0

    # determinism requires only 1 cuda stream
    if "random_state" in kwargs:
        kwargs["n_streams"] = 1

    return kwargs, gpuaccel
70 changes: 69 additions & 1 deletion python/cuml/cuml/ensemble/randomforestregressor.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

# distutils: language = c++


import sys
import threading

from cuml.internals.api_decorators import device_interop_preparation
from cuml.internals.api_decorators import enable_device_interop
from cuml.internals.safe_imports import (
cpu_only_import,
gpu_only_import,
Expand All @@ -27,7 +32,9 @@ nvtx_annotate = gpu_only_import_from("nvtx", "annotate", alt=null_decorator)
rmm = gpu_only_import('rmm')

from cuml.internals.array import CumlArray
from cuml.internals.global_settings import GlobalSettings
import cuml.internals
from cuml.internals import logger

from cuml.internals.mixins import RegressorMixin
from cuml.internals.logger cimport level_enum
Expand Down Expand Up @@ -251,6 +258,22 @@ class RandomForestRegressor(BaseRandomForestModel,
<https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html>`_.
"""

# Dotted import path of the CPU estimator that corresponds to this class.
# NOTE(review): presumably resolved by the device-interop base class to
# build self._cpu_model — confirm.
_cpu_estimator_import_path = 'sklearn.ensemble.RandomForestRegressor'

# Translation table for hyperparameters, keyed by parameter name. A nested
# dict maps specific incoming values to replacement values.
# NOTE(review): the string "NotImplemented" looks like a marker for
# parameters/values without a cuML equivalent — confirm against the parent
# class's _hyperparam_translator.
_hyperparam_interop_translator = {
# Any value of `criterion` is marked "NotImplemented".
"criterion": "NotImplemented",
# Only oob_score=True is marked "NotImplemented".
"oob_score": {
True: "NotImplemented",
},
# max_depth=None is replaced by 16, the max_depth default of
# BaseRandomForestModel.__init__.
"max_depth": {
None: 16,
},
# max_samples=None is replaced by 1.0.
"max_samples": {
None: 1.0,
},
}

@device_interop_preparation
def __init__(self, *,
split_criterion=2,
accuracy_metric='r2',
Expand Down Expand Up @@ -297,6 +320,9 @@ class RandomForestRegressor(BaseRandomForestModel,
state["treelite_handle"] = None
state["split_criterion"] = self.split_criterion

if "_cpu_model_class_lock" in state:
del state["_cpu_model_class_lock"]

return state

def __setstate__(self, state):
Expand All @@ -318,6 +344,7 @@ class RandomForestRegressor(BaseRandomForestModel,

self.treelite_serialized_model = state["treelite_serialized_model"]
self.__dict__.update(state)
self._cpu_model_class_lock = threading.RLock()

def __del__(self):
self._reset_forest_data()
Expand All @@ -342,6 +369,9 @@ class RandomForestRegressor(BaseRandomForestModel,
self.treelite_serialized_model = None
self.n_cols = None

def get_attr_names(self):
    """Return an empty list of attribute names.

    NOTE(review): presumably consumed by the device-interop base class to
    decide which attributes to carry between CPU and GPU models — confirm.
    """
    return list()

def convert_to_treelite_model(self):
"""
Converts the cuML RF model to a Treelite model
Expand Down Expand Up @@ -413,6 +443,7 @@ class RandomForestRegressor(BaseRandomForestModel,
domain="cuml_python")
@generate_docstring()
@cuml.internals.api_base_return_any_skipall
@enable_device_interop
def fit(self, X, y, convert_dtype=True):
"""
Perform Random Forest Regression on the input data
Expand Down Expand Up @@ -535,6 +566,7 @@ class RandomForestRegressor(BaseRandomForestModel,
domain="cuml_python")
@insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')],
return_values=[('dense', '(n_samples, 1)')])
@enable_device_interop
def predict(self, X, predict_model="GPU",
algo='auto', convert_dtype=True,
fil_sparse_format='auto') -> CumlArray:
Expand Down Expand Up @@ -752,3 +784,39 @@ class RandomForestRegressor(BaseRandomForestModel,
if self.dtype == np.float64:
return get_rf_json(rf_forest64).decode('utf-8')
return get_rf_json(rf_forest).decode('utf-8')

def cpu_to_gpu(self):
    """Run the parent ``cpu_to_gpu``, unmasking the CPU model's class if needed.

    When the accelerator is active, ``self._cpu_model.__class__`` is
    temporarily set to the ``RandomForestRegressor`` found on
    ``sys.modules['sklearn.ensemble']`` for the duration of the parent call,
    and restored afterwards even if that call raises. The swap is done while
    holding ``self._cpu_model_class_lock``.
    """
    if not GlobalSettings().accelerator_active:
        super().cpu_to_gpu()
        return

    # treelite does an internal isinstance check to detect an sklearn RF,
    # which proxymodule interferes with. Work around that only for the
    # duration of the treelite check and put the original class back after.
    with self._cpu_model_class_lock:
        proxied_class = self._cpu_model.__class__
        native_class = sys.modules['sklearn.ensemble'].RandomForestRegressor
        self._cpu_model.__class__ = native_class
        try:
            super().cpu_to_gpu()
        finally:
            self._cpu_model.__class__ = proxied_class

@classmethod
def _hyperparam_translator(cls, **kwargs):
    """Adjust translated hyperparameters for ``RandomForestRegressor``.

    Delegates to the parent class's ``_hyperparam_translator`` first, then
    applies two regressor-specific adjustments:

    * an integer ``max_samples`` is replaced by ``1.0`` and a warning is
      logged;
    * if ``random_state`` is present in the keyword arguments,
      ``n_streams`` is forced to ``1``.

    Parameters
    ----------
    **kwargs
        Hyperparameters to translate.

    Returns
    -------
    tuple
        ``(kwargs, gpuaccel)``: the adjusted keyword arguments and the
        second value returned by the parent translator, passed through
        unchanged.
    """
    kwargs, gpuaccel = super(RandomForestRegressor, cls)._hyperparam_translator(**kwargs)

    if "max_samples" in kwargs and isinstance(kwargs["max_samples"], int):
        # Adjacent string literals are concatenated verbatim, so the
        # separating space must be part of the first literal.
        logger.warn(
            f"Integer value of max_samples={kwargs['max_samples']} "
            "not supported, changed to 1.0."
        )
        kwargs["max_samples"] = 1.0

    # determinism requires only 1 cuda stream
    if "random_state" in kwargs:
        kwargs["n_streams"] = 1

    return kwargs, gpuaccel
Loading

0 comments on commit f60b5f0

Please sign in to comment.