Commit af5e420

FEA fix init signature of RF for interop and add RF to accel

dantegd committed Feb 6, 2025 · 1 parent 14de36c
Showing 6 changed files with 447 additions and 5 deletions.
5 changes: 3 additions & 2 deletions python/cuml/cuml/ensemble/randomforest_common.pyx
@@ -16,6 +16,7 @@
 import treelite.sklearn
 from cuml.internals.safe_imports import gpu_only_import
 from cuml.internals.api_decorators import device_interop_preparation
+from cuml.internals.global_settings import GlobalSettings
 
 cp = gpu_only_import('cupy')
 import math
@@ -92,7 +93,7 @@ class BaseRandomForestModel(UniversalBase):
                           "class_weight": class_weight}
 
         for key, vals in sklearn_params.items():
-            if vals:
+            if vals and not GlobalSettings().accelerator_active:
                 raise TypeError(
                     " The Scikit-learn variable ", key,
                     " is not supported in cuML,"
@@ -101,7 +102,7 @@ class BaseRandomForestModel(UniversalBase):
                     "api.html#random-forest) for more information")
 
         for key in kwargs.keys():
-            if key not in self._param_names:
+            if key not in self._param_names and not GlobalSettings().accelerator_active:
                 raise TypeError(
                     " The variable ", key,
                     " is not supported in cuML,"
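The effect of this guard is that sklearn-only constructor arguments are rejected only when the accelerator is inactive; under `cuml.accel`, unmodified scikit-learn code is tolerated. A minimal, self-contained sketch of the pattern, with `Settings` as a stand-in for cuML's `GlobalSettings` (the names below are illustrative, not the real cuML internals):

```python
# Sketch of the guard above; `Settings` stands in for
# cuml.internals.global_settings.GlobalSettings.

class Settings:
    accelerator_active = False  # flipped to True when cuml.accel is enabled


def check_params(sklearn_params, param_names, kwargs):
    # Reject sklearn-only arguments, but only on the strict (non-accel) path.
    for key, val in sklearn_params.items():
        if val and not Settings().accelerator_active:
            raise TypeError(f"The Scikit-learn variable {key} is not supported in cuML")
    # Same leniency for unknown keyword arguments.
    for key in kwargs:
        if key not in param_names and not Settings().accelerator_active:
            raise TypeError(f"The variable {key} is not supported in cuML")


check_params({"oob_score": False}, {"n_estimators"}, {"n_estimators": 10})  # passes
Settings.accelerator_active = True
check_params({"oob_score": True}, {"n_estimators"}, {"warm_start": True})   # tolerated
```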
51 changes: 50 additions & 1 deletion python/cuml/cuml/ensemble/randomforestclassifier.pyx
@@ -15,7 +15,10 @@
 # limitations under the License.
 #
 
+
 # distutils: language = c++
+import sys
+
 from cuml.internals.api_decorators import device_interop_preparation
 from cuml.internals.api_decorators import enable_device_interop
 from cuml.internals.safe_imports import (
@@ -31,7 +34,9 @@ rmm = gpu_only_import('rmm')
 
 from cuml.internals.array import CumlArray
 from cuml.internals.mixins import ClassifierMixin
+from cuml.internals.global_settings import GlobalSettings
 import cuml.internals
+from cuml.internals import logger
 from cuml.common.doc_utils import generate_docstring
 from cuml.common.doc_utils import insert_into_docstring
 from cuml.common import input_to_cuml_array
@@ -252,6 +257,19 @@ class RandomForestClassifier(BaseRandomForestModel,
 
     _cpu_estimator_import_path = 'sklearn.ensemble.RandomForestClassifier'
 
+    _hyperparam_interop_translator = {
+        "criterion": "NotImplemented",
+        "oob_score": {
+            True: "NotImplemented",
+        },
+        "max_depth": {
+            None: 16,
+        },
+        "max_samples": {
+            None: 1.0,
+        },
+    }
+
     @device_interop_preparation
     def __init__(self, *, split_criterion=0, handle=None, verbose=False,
                  output_type=None,
@@ -438,7 +456,6 @@ class RandomForestClassifier(BaseRandomForestModel,
             y to be of dtype int32. This will increase memory used for
             the method.
         """
-
         X_m, y_m, max_feature_val = self._dataset_setup_for_fit(X, y,
                                                                 convert_dtype)
         # Track the labels to see if update is necessary
@@ -838,3 +855,35 @@ class RandomForestClassifier(BaseRandomForestModel,
         if self.dtype == np.float64:
             return get_rf_json(rf_forest64).decode('utf-8')
         return get_rf_json(rf_forest).decode('utf-8')
+
+    def cpu_to_gpu(self):
+        # treelite does an internal isinstance check to detect an sklearn
+        # RF, which proxymodule interferes with. We work around that
+        # temporarily here, just for the treelite-internal check, and
+        # restore the __class__ at the end of the method.
+        if GlobalSettings().accelerator_active:
+            cls_cached = self._cpu_model.__class__
+            self._cpu_model.__class__ = sys.modules['sklearn.ensemble'].RandomForestClassifier
+
+        super().cpu_to_gpu()
+
+        if GlobalSettings().accelerator_active:
+            self._cpu_model.__class__ = cls_cached
+
+    @classmethod
+    def _hyperparam_translator(cls, **kwargs):
+        kwargs, gpuaccel = super(RandomForestClassifier, cls)._hyperparam_translator(**kwargs)
+
+        if "max_samples" in kwargs:
+            if isinstance(kwargs["max_samples"], int):
+                logger.warn(
+                    f"Integer value of max_samples={kwargs['max_samples']} "
+                    "is not supported, changed to 1.0."
+                )
+                kwargs["max_samples"] = 1.0
+
+        # Determinism requires a single CUDA stream.
+        if "random_state" in kwargs:
+            kwargs["n_streams"] = 1
+
+        return kwargs, gpuaccel
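The `_hyperparam_interop_translator` table and the `_hyperparam_translator` override drive how scikit-learn hyperparameters are coerced for the GPU path. As a hedged sketch of how such a table could be consumed (the real logic lives in the superclass call, presumably `UniversalBase`; the `translate` function below is a hypothetical stand-in):

```python
# Hypothetical reading of a translator table like the one above: map
# sklearn-only values to cuML equivalents, or flag the fit for CPU fallback.

TRANSLATOR = {
    "criterion": "NotImplemented",          # no GPU equivalent for this key
    "oob_score": {True: "NotImplemented"},  # only True is unsupported
    "max_depth": {None: 16},                # sklearn's "unlimited" -> cuML default
    "max_samples": {None: 1.0},             # None means "use every row"
}


def translate(kwargs):
    out, gpu_ok = dict(kwargs), True
    for key, value in kwargs.items():
        rule = TRANSLATOR.get(key)
        if rule == "NotImplemented":
            gpu_ok = False                  # whole key unsupported on GPU
        elif isinstance(rule, dict) and value in rule:
            if rule[value] == "NotImplemented":
                gpu_ok = False              # this particular value unsupported
            else:
                out[key] = rule[value]      # e.g. max_depth=None becomes 16
    return out, gpu_ok


print(translate({"max_depth": None, "oob_score": False}))
# ({'max_depth': 16, 'oob_score': False}, True)
```

Note also that the override above pins `n_streams` to 1 whenever `random_state` is set, since deterministic forests require a single CUDA stream.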
49 changes: 49 additions & 0 deletions python/cuml/cuml/ensemble/randomforestregressor.pyx
@@ -15,6 +15,8 @@
 #
 # distutils: language = c++
 
+
+import sys
 from cuml.internals.api_decorators import device_interop_preparation
 from cuml.internals.api_decorators import enable_device_interop
 from cuml.internals.safe_imports import (
@@ -28,7 +30,9 @@ nvtx_annotate = gpu_only_import_from("nvtx", "annotate", alt=null_decorator)
 rmm = gpu_only_import('rmm')
 
 from cuml.internals.array import CumlArray
+from cuml.internals.global_settings import GlobalSettings
 import cuml.internals
+from cuml.internals import logger
 
 from cuml.internals.mixins import RegressorMixin
 from cuml.internals.logger cimport level_enum
@@ -254,6 +258,19 @@ class RandomForestRegressor(BaseRandomForestModel,
 
     _cpu_estimator_import_path = 'sklearn.ensemble.RandomForestRegressor'
 
+    _hyperparam_interop_translator = {
+        "criterion": "NotImplemented",
+        "oob_score": {
+            True: "NotImplemented",
+        },
+        "max_depth": {
+            None: 16,
+        },
+        "max_samples": {
+            None: 1.0,
+        },
+    }
+
     @device_interop_preparation
     def __init__(self, *,
                  split_criterion=2,
@@ -761,3 +778,35 @@ class RandomForestRegressor(BaseRandomForestModel,
         if self.dtype == np.float64:
             return get_rf_json(rf_forest64).decode('utf-8')
         return get_rf_json(rf_forest).decode('utf-8')
+
+    def cpu_to_gpu(self):
+        # treelite does an internal isinstance check to detect an sklearn
+        # RF, which proxymodule interferes with. We work around that
+        # temporarily here, just for the treelite-internal check, and
+        # restore the __class__ at the end of the method.
+        if GlobalSettings().accelerator_active:
+            cls_cached = self._cpu_model.__class__
+            self._cpu_model.__class__ = sys.modules['sklearn.ensemble'].RandomForestRegressor
+
+        super().cpu_to_gpu()
+
+        if GlobalSettings().accelerator_active:
+            self._cpu_model.__class__ = cls_cached
+
+    @classmethod
+    def _hyperparam_translator(cls, **kwargs):
+        kwargs, gpuaccel = super(RandomForestRegressor, cls)._hyperparam_translator(**kwargs)
+
+        if "max_samples" in kwargs:
+            if isinstance(kwargs["max_samples"], int):
+                logger.warn(
+                    f"Integer value of max_samples={kwargs['max_samples']} "
+                    "is not supported, changed to 1.0."
+                )
+                kwargs["max_samples"] = 1.0
+
+        # Determinism requires a single CUDA stream.
+        if "random_state" in kwargs:
+            kwargs["n_streams"] = 1
+
+        return kwargs, gpuaccel
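Both `cpu_to_gpu` overrides use the same workaround: temporarily restoring the proxied model's real class so treelite's internal `isinstance` check passes. A generic sketch of that pattern as a hypothetical context manager (`real_class` is not a cuML API):

```python
# Generic sketch: swap an object's __class__ so library-internal
# isinstance() checks see the real type, then restore the proxy.

import contextlib
import sys


@contextlib.contextmanager
def real_class(obj, module_name, class_name):
    cached = obj.__class__  # remember the proxy class
    obj.__class__ = getattr(sys.modules[module_name], class_name)
    try:
        yield obj           # e.g. treelite now recognizes the sklearn model
    finally:
        obj.__class__ = cached  # always restore the proxy


# Usage would mirror the diff above, e.g.:
# with real_class(self._cpu_model, "sklearn.ensemble", "RandomForestRegressor"):
#     super().cpu_to_gpu()
```

Reassigning `__class__` like this only works when the two classes have compatible instance layouts; the diff relies on the proxied model being an ordinary Python object for which that holds.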
22 changes: 20 additions & 2 deletions python/cuml/cuml/internals/api_decorators.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -365,7 +365,25 @@ def processor(self, *args, **kwargs):
         # Save all kwargs
         self._full_kwargs = kwargs
         # Generate list of available cuML hyperparameters
-        gpu_hyperparams = list(inspect.signature(init_func).parameters.keys())
+
+        from cuml.ensemble.randomforest_common import BaseRandomForestModel
+
+        # The Random Forest models' init signature is a combination of the
+        # parameters of BaseRandomForestModel and of the regressor/classifier
+        # classes, so we need to join their init hyperparameters.
+        gpu_hyperparams = []
+        if isinstance(self, BaseRandomForestModel):
+            gpu_hyperparams.extend(
+                list(
+                    inspect.signature(
+                        BaseRandomForestModel.__init__
+                    ).parameters.keys()
+                )
+            )
+
+        gpu_hyperparams.extend(
+            list(inspect.signature(init_func).parameters.keys())
+        )
 
         # Filter provided parameters for cuML estimator initialization
         filtered_kwargs = {}
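The joined-signature logic can be exercised in isolation. A small sketch with hypothetical stand-in classes (`Base`/`Concrete` are not cuML types):

```python
# Stand-alone sketch of joining __init__ signatures across a class hierarchy,
# as the api_decorators change does for the Random Forest estimators.

import inspect


class Base:
    def __init__(self, n_estimators=100, max_depth=16):
        self.n_estimators, self.max_depth = n_estimators, max_depth


class Concrete(Base):
    def __init__(self, *, split_criterion=0, **kwargs):
        super().__init__(**kwargs)
        self.split_criterion = split_criterion


params = []
params.extend(inspect.signature(Base.__init__).parameters.keys())
params.extend(inspect.signature(Concrete.__init__).parameters.keys())

# Deduplicate while preserving order, and drop 'self' / the **kwargs catch-all.
hyperparams = [p for p in dict.fromkeys(params) if p not in ("self", "kwargs")]
print(hyperparams)  # ['n_estimators', 'max_depth', 'split_criterion']
```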
167 changes: 167 additions & 0 deletions (new test file; path not shown in this view)
@@ -0,0 +1,167 @@
#
# Copyright (c) 2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import pytest
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score


@pytest.fixture(scope="module")
def classification_data():
# Create a synthetic classification dataset.
X, y = make_classification(
n_samples=300, n_features=20, n_informative=10, n_redundant=5, random_state=42
)
return X, y


@pytest.mark.parametrize("n_estimators", [10, 50, 100])
def test_rf_n_estimators(classification_data, n_estimators):
X, y = classification_data
clf = RandomForestClassifier(n_estimators=n_estimators, random_state=42)
clf.fit(X, y)
_ = accuracy_score(y, clf.predict(X))


@pytest.mark.parametrize("criterion", ["gini", "entropy"])
def test_rf_criterion(classification_data, criterion):
X, y = classification_data
clf = RandomForestClassifier(criterion=criterion, n_estimators=50, random_state=42)
clf.fit(X, y)
_ = accuracy_score(y, clf.predict(X))


@pytest.mark.parametrize("max_depth", [None, 5, 10])
def test_rf_max_depth(classification_data, max_depth):
X, y = classification_data
clf = RandomForestClassifier(max_depth=max_depth, n_estimators=50, random_state=42)
clf.fit(X, y)
_ = accuracy_score(y, clf.predict(X))


@pytest.mark.parametrize("min_samples_split", [2, 5, 10])
def test_rf_min_samples_split(classification_data, min_samples_split):
X, y = classification_data
clf = RandomForestClassifier(min_samples_split=min_samples_split, n_estimators=50, random_state=42)
clf.fit(X, y)
_ = accuracy_score(y, clf.predict(X))


@pytest.mark.parametrize("min_samples_leaf", [1, 2, 4])
def test_rf_min_samples_leaf(classification_data, min_samples_leaf):
X, y = classification_data
clf = RandomForestClassifier(min_samples_leaf=min_samples_leaf, n_estimators=50, random_state=42)
clf.fit(X, y)
_ = accuracy_score(y, clf.predict(X))


@pytest.mark.parametrize("min_weight_fraction_leaf", [0.0, 0.1])
def test_rf_min_weight_fraction_leaf(classification_data, min_weight_fraction_leaf):
X, y = classification_data
clf = RandomForestClassifier(min_weight_fraction_leaf=min_weight_fraction_leaf, n_estimators=50, random_state=42)
clf.fit(X, y)
_ = accuracy_score(y, clf.predict(X))


@pytest.mark.parametrize("max_features", ["sqrt", "log2", 0.5, 5])
def test_rf_max_features(classification_data, max_features):
X, y = classification_data
clf = RandomForestClassifier(max_features=max_features, n_estimators=50, random_state=42)
clf.fit(X, y)
_ = accuracy_score(y, clf.predict(X))


@pytest.mark.parametrize("max_leaf_nodes", [None, 10, 20])
def test_rf_max_leaf_nodes(classification_data, max_leaf_nodes):
X, y = classification_data
clf = RandomForestClassifier(max_leaf_nodes=max_leaf_nodes, n_estimators=50, random_state=42)
clf.fit(X, y)
_ = accuracy_score(y, clf.predict(X))


@pytest.mark.parametrize("min_impurity_decrease", [0.0, 0.1])
def test_rf_min_impurity_decrease(classification_data, min_impurity_decrease):
X, y = classification_data
clf = RandomForestClassifier(min_impurity_decrease=min_impurity_decrease, n_estimators=50, random_state=42)
clf.fit(X, y)
_ = accuracy_score(y, clf.predict(X))


@pytest.mark.parametrize("bootstrap", [True, False])
def test_rf_bootstrap(classification_data, bootstrap):
X, y = classification_data
clf = RandomForestClassifier(bootstrap=bootstrap, n_estimators=50, random_state=42)
clf.fit(X, y)
_ = accuracy_score(y, clf.predict(X))


@pytest.mark.parametrize("n_jobs", [1, -1])
def test_rf_n_jobs(classification_data, n_jobs):
X, y = classification_data
clf = RandomForestClassifier(n_jobs=n_jobs, n_estimators=50, random_state=42)
clf.fit(X, y)
_ = accuracy_score(y, clf.predict(X))


@pytest.mark.parametrize("verbose", [0, 1])
def test_rf_verbose(classification_data, verbose):
X, y = classification_data
clf = RandomForestClassifier(verbose=verbose, n_estimators=50, random_state=42)
clf.fit(X, y)
_ = accuracy_score(y, clf.predict(X))


@pytest.mark.parametrize("warm_start", [False, True])
def test_rf_warm_start(classification_data, warm_start):
X, y = classification_data
clf = RandomForestClassifier(warm_start=warm_start, n_estimators=50, random_state=42)
clf.fit(X, y)
_ = accuracy_score(y, clf.predict(X))


@pytest.mark.parametrize("class_weight", [None, "balanced", {0: 1, 1: 2}])
def test_rf_class_weight(classification_data, class_weight):
X, y = classification_data
clf = RandomForestClassifier(class_weight=class_weight, n_estimators=50, random_state=42)
clf.fit(X, y)
_ = accuracy_score(y, clf.predict(X))


@pytest.mark.parametrize("ccp_alpha", [0.0, 0.1])
def test_rf_ccp_alpha(classification_data, ccp_alpha):
X, y = classification_data
clf = RandomForestClassifier(ccp_alpha=ccp_alpha, n_estimators=50, random_state=42)
clf.fit(X, y)
_ = accuracy_score(y, clf.predict(X))


@pytest.mark.parametrize("max_samples", [None, 0.8, 50])
def test_rf_max_samples(classification_data, max_samples):
X, y = classification_data
clf = RandomForestClassifier(max_samples=max_samples, bootstrap=True, n_estimators=50, random_state=42)
clf.fit(X, y)
_ = accuracy_score(y, clf.predict(X))


def test_rf_random_state(classification_data):
X, y = classification_data
clf1 = RandomForestClassifier(n_estimators=50, random_state=42).fit(X, y)
clf2 = RandomForestClassifier(n_estimators=50, random_state=42).fit(X, y)
# Predictions should be identical with the same random_state.
assert np.array_equal(clf1.predict(X), clf2.predict(X))
(The remaining changed file was not loaded in this view.)