diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 1f9b284a1b..800dcb4689 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -13,6 +13,7 @@ jobs: # Please keep pr-builder as the top job here pr-builder: needs: + - check-nightly-ci - changed-files - checks - clang-tidy @@ -43,6 +44,18 @@ jobs: - name: Telemetry setup if: ${{ vars.TELEMETRY_ENABLED == 'true' }} uses: rapidsai/shared-actions/telemetry-dispatch-stash-base-env-vars@main + check-nightly-ci: + # Switch to ubuntu-latest once it defaults to a version of Ubuntu that + # provides at least Python 3.11 (see + # https://docs.python.org/3/library/datetime.html#datetime.date.fromisoformat) + runs-on: ubuntu-24.04 + env: + RAPIDS_GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + steps: + - name: Check if nightly CI is passing + uses: rapidsai/shared-actions/check_nightly_success/dispatch@main + with: + repo: cuml changed-files: secrets: inherit needs: telemetry-setup diff --git a/ci/build_wheel.sh b/ci/build_wheel.sh index 3c840d9849..c7c7eccfb9 100755 --- a/ci/build_wheel.sh +++ b/ci/build_wheel.sh @@ -16,7 +16,7 @@ cd "${package_dir}" sccache --zero-stats rapids-logger "Building '${package_name}' wheel" -python -m pip wheel \ +rapids-pip-retry wheel \ -w dist \ -v \ --no-deps \ diff --git a/ci/build_wheel_libcuml.sh b/ci/build_wheel_libcuml.sh index ad38eab617..7c719a6380 100755 --- a/ci/build_wheel_libcuml.sh +++ b/ci/build_wheel_libcuml.sh @@ -18,7 +18,7 @@ rapids-dependency-file-generator \ | tee /tmp/requirements-build.txt rapids-logger "Installing build requirements" -python -m pip install \ +rapids-pip-retry install \ -v \ --prefer-binary \ -r /tmp/requirements-build.txt diff --git a/ci/test_wheel.sh b/ci/test_wheel.sh index 8027876005..7e1feff80a 100755 --- a/ci/test_wheel.sh +++ b/ci/test_wheel.sh @@ -9,7 +9,7 @@ RAPIDS_PY_WHEEL_NAME="cuml_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from RAPIDS_PY_WHEEL_NAME="libcuml_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 cpp ./dist # echo to expand wildcard before adding `[extra]` requires for pip -python -m pip install \ +rapids-pip-retry install \ ./dist/libcuml*.whl \ "$(echo ./dist/cuml*.whl)[test]" diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index c99003758e..f21729a695 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -118,31 +118,42 @@ inline void launcher(const raft::handle_t& handle, // TODO: use nndescent from cuvs RAFT_EXPECTS(static_cast(n_neighbors) <= params->nn_descent_params.graph_degree, "n_neighbors should be smaller than the graph degree computed by nn descent"); + RAFT_EXPECTS(params->nn_descent_params.return_distances, + "return_distances for nn descent should be set to true to be used for UMAP"); auto graph = get_graph_nnd(handle, inputsA, params); - auto indices_d = raft::make_device_matrix( - handle, inputsA.n, params->nn_descent_params.graph_degree); - - raft::copy(indices_d.data_handle(), - graph.graph().data_handle(), - inputsA.n * params->nn_descent_params.graph_degree, - stream); - + // `graph.graph()` is a host array (n x graph_degree). + // Slice and copy to a temporary host array (n x n_neighbors), then copy + // that to the output device array `out.knn_indices` (n x n_neighbors). + // TODO: force graph_degree = n_neighbors so the temporary host array and + // slice isn't necessary. + auto temp_indices_h = raft::make_host_matrix(inputsA.n, n_neighbors); + size_t graph_degree = params->nn_descent_params.graph_degree; +#pragma omp parallel for + for (size_t i = 0; i < static_cast(inputsA.n); i++) { + for (int j = 0; j < n_neighbors; j++) { + auto target = temp_indices_h.data_handle(); + auto source = graph.graph().data_handle(); + target[i * n_neighbors + j] = source[i * graph_degree + j]; + } + } + raft::copy(handle, + raft::make_device_matrix_view(out.knn_indices, inputsA.n, n_neighbors), + temp_indices_h.view()); + + // `graph.distances()` is a device array (n x graph_degree). + // Slice and copy to the output device array `out.knn_dists` (n x n_neighbors). + // TODO: force graph_degree = n_neighbors so this slice isn't necessary. raft::matrix::slice_coordinates coords{static_cast(0), static_cast(0), static_cast(inputsA.n), static_cast(n_neighbors)}; - - RAFT_EXPECTS(graph.distances().has_value(), - "return_distances for nn descent should be set to true to be used for UMAP"); - auto out_knn_dists_view = raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors); raft::matrix::slice( - handle, raft::make_const_mdspan(graph.distances().value()), out_knn_dists_view, coords); - auto out_knn_indices_view = - raft::make_device_matrix_view(out.knn_indices, inputsA.n, n_neighbors); - raft::matrix::slice( - handle, raft::make_const_mdspan(indices_d.view()), out_knn_indices_view, coords); + handle, + raft::make_const_mdspan(graph.distances().value()), + raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors), + coords); } } diff --git a/python/cuml/cuml/cluster/kmeans.pyx b/python/cuml/cuml/cluster/kmeans.pyx index ed26df5cd6..dac4c9ded7 100644 --- a/python/cuml/cuml/cluster/kmeans.pyx +++ b/python/cuml/cuml/cluster/kmeans.pyx @@ -21,6 +21,7 @@ np = cpu_only_import('numpy') from cuml.internals.safe_imports import gpu_only_import rmm = gpu_only_import('rmm') from cuml.internals.safe_imports import safe_import_from, return_false +from cuml.internals.utils import check_random_seed import typing IF GPUBUILD == 1: @@ -209,8 +210,11 @@ class KMeans(UniversalBase, params.init = self._params_init params.max_iter = self.max_iter params.tol = self.tol + # After transferring from one device to another `_seed` might not be set + # so we need to pass a dummy value here. Its value does not matter as the + # seed is only used during fitting + params.rng_state.seed = getattr(self, "_seed", 0) params.verbosity = (self.verbose) - params.rng_state.seed = self.random_state params.metric = DistanceType.L2Expanded # distance metric as squared L2: @todo - support other metrics # noqa: E501 params.batch_samples = self.max_samples_per_batch params.oversampling_factor = self.oversampling_factor @@ -307,6 +311,7 @@ class KMeans(UniversalBase, else None), check_dtype=check_dtype) + self._seed = check_random_seed(self.random_state) self.feature_names_in_ = _X_m.index IF GPUBUILD == 1: diff --git a/python/cuml/cuml/cluster/kmeans_mg.pyx b/python/cuml/cuml/cluster/kmeans_mg.pyx index cf0a2967c4..1294467755 100644 --- a/python/cuml/cuml/cluster/kmeans_mg.pyx +++ b/python/cuml/cuml/cluster/kmeans_mg.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2024, NVIDIA CORPORATION. +# Copyright (c) 2019-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ from cuml.common import input_to_cuml_array from cuml.cluster import KMeans from cuml.cluster.kmeans_utils cimport params as KMeansParams +from cuml.internals.utils import check_random_seed cdef extern from "cuml/cluster/kmeans_mg.hpp" \ @@ -129,6 +130,8 @@ class KMeansMG(KMeans): cdef uintptr_t sample_weight_ptr = sample_weight_m.ptr + self._seed = check_random_seed(self.random_state) + if (self.init in ['scalable-k-means++', 'k-means||', 'random']): self.cluster_centers_ = CumlArray.zeros(shape=(self.n_clusters, self.n_cols), diff --git a/python/cuml/cuml/decomposition/pca.pyx b/python/cuml/cuml/decomposition/pca.pyx index db2f0f62c8..9a055e6505 100644 --- a/python/cuml/cuml/decomposition/pca.pyx +++ b/python/cuml/cuml/decomposition/pca.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2024, NVIDIA CORPORATION. +# Copyright (c) 2019-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -209,9 +209,6 @@ class PCA(UniversalBase, ``n_components = min(n_samples, n_features)`` - random_state : int / None (default = None) - If you want results to be the same when you restart Python, select a - state. svd_solver : 'full' or 'jacobi' or 'auto' (default = 'full') Full uses a eigendecomposition of the covariance matrix then discards components. @@ -292,7 +289,7 @@ class PCA(UniversalBase, @device_interop_preparation def __init__(self, *, copy=True, handle=None, iterated_power=15, - n_components=None, random_state=None, svd_solver='auto', + n_components=None, svd_solver='auto', tol=1e-7, verbose=False, whiten=False, output_type=None): # parameters @@ -302,7 +299,6 @@ class PCA(UniversalBase, self.copy = copy self.iterated_power = iterated_power self.n_components = n_components - self.random_state = random_state self.svd_solver = svd_solver self.tol = tol self.whiten = whiten @@ -739,7 +735,7 @@ class PCA(UniversalBase, def _get_param_names(cls): return super()._get_param_names() + \ ["copy", "iterated_power", "n_components", "svd_solver", "tol", - "whiten", "random_state"] + "whiten"] def _check_is_fitted(self, attr): if not hasattr(self, attr) or (getattr(self, attr) is None): diff --git a/python/cuml/cuml/ensemble/randomforest_common.pyx b/python/cuml/cuml/ensemble/randomforest_common.pyx index 38c15eaca2..46dd790dae 100644 --- a/python/cuml/cuml/ensemble/randomforest_common.pyx +++ b/python/cuml/cuml/ensemble/randomforest_common.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2020-2024, NVIDIA CORPORATION. +# Copyright (c) 2020-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import threading +import treelite.sklearn from cuml.internals.safe_imports import gpu_only_import +from cuml.internals.api_decorators import device_interop_preparation +from cuml.internals.global_settings import GlobalSettings + cp = gpu_only_import('cupy') import math import warnings @@ -24,7 +29,7 @@ np = cpu_only_import('numpy') from cuml import ForestInference from cuml.fil.fil import TreeliteModel from pylibraft.common.handle import Handle -from cuml.internals.base import Base +from cuml.internals.base import UniversalBase from cuml.internals.array import CumlArray from cuml.common.exceptions import NotFittedError import cuml.internals @@ -39,7 +44,7 @@ from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.prims.label.classlabels import make_monotonic, check_labels -class BaseRandomForestModel(Base): +class BaseRandomForestModel(UniversalBase): _param_names = ['n_estimators', 'max_depth', 'handle', 'max_features', 'n_bins', 'split_criterion', 'min_samples_leaf', @@ -67,6 +72,7 @@ class BaseRandomForestModel(Base): classes_ = CumlArrayDescriptor() + @device_interop_preparation def __init__(self, *, split_criterion, n_streams=4, n_estimators=100, max_depth=16, handle=None, max_features='sqrt', n_bins=128, bootstrap=True, @@ -88,7 +94,7 @@ class BaseRandomForestModel(Base): "class_weight": class_weight} for key, vals in sklearn_params.items(): - if vals: + if vals and not GlobalSettings().accelerator_active: raise TypeError( " The Scikit-learn variable ", key, " is not supported in cuML," @@ -97,7 +103,7 @@ class BaseRandomForestModel(Base): "api.html#random-forest) for more information") for key in kwargs.keys(): - if key not in self._param_names: + if key not in self._param_names and not GlobalSettings().accelerator_active: raise TypeError( " The variable ", key, " is not supported in cuML," @@ -154,6 +160,7 @@ class BaseRandomForestModel(Base): self.model_pbuf_bytes = bytearray() self.treelite_handle = None self.treelite_serialized_model = None + self._cpu_model_class_lock = threading.RLock() def _get_max_feat_val(self) -> float: if isinstance(self.max_features, int): @@ -268,6 +275,24 @@ class BaseRandomForestModel(Base): self.treelite_handle = tl_handle return self.treelite_handle + def cpu_to_gpu(self): + tl_model = treelite.sklearn.import_model(self._cpu_model) + self._temp = TreeliteModel.from_treelite_bytes(tl_model.serialize_bytes()) + self.treelite_serialized_model = treelite_serialize(self._temp.handle) + self._obtain_treelite_handle() + self.dtype = np.float64 + self.update_labels = False + super().cpu_to_gpu() + + def gpu_to_cpu(self): + self._obtain_treelite_handle() + tl_model = TreeliteModel.from_treelite_model_handle( + self.treelite_handle, + take_handle_ownership=False) + tl_bytes = tl_model.to_treelite_bytes() + tl_model2 = treelite.Model.deserialize_bytes(tl_bytes) + self._cpu_model = treelite.sklearn.export_model(tl_model2) + @cuml.internals.api_base_return_generic(set_output_type=True, set_n_features_in=True, get_output_type=False) diff --git a/python/cuml/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/cuml/ensemble/randomforestclassifier.pyx index 5198d60b28..f4e926d0cc 100644 --- a/python/cuml/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/cuml/ensemble/randomforestclassifier.pyx @@ -15,7 +15,13 @@ # limitations under the License. # + # distutils: language = c++ +import sys +import threading + +from cuml.internals.api_decorators import device_interop_preparation +from cuml.internals.api_decorators import enable_device_interop from cuml.internals.safe_imports import ( cpu_only_import, gpu_only_import, @@ -29,10 +35,13 @@ rmm = gpu_only_import('rmm') from cuml.internals.array import CumlArray from cuml.internals.mixins import ClassifierMixin +from cuml.internals.global_settings import GlobalSettings import cuml.internals +from cuml.internals import logger from cuml.common.doc_utils import generate_docstring from cuml.common.doc_utils import insert_into_docstring from cuml.common import input_to_cuml_array +from cuml.internals.utils import check_random_seed from cuml.internals.logger cimport level_enum from cuml.ensemble.randomforest_common import BaseRandomForestModel @@ -247,6 +256,22 @@ class RandomForestClassifier(BaseRandomForestModel, `_. """ + _cpu_estimator_import_path = 'sklearn.ensemble.RandomForestClassifier' + + _hyperparam_interop_translator = { + "criterion": "NotImplemented", + "oob_score": { + True: "NotImplemented", + }, + "max_depth": { + None: 16, + }, + "max_samples": { + None: 1.0, + }, + } + + @device_interop_preparation def __init__(self, *, split_criterion=0, handle=None, verbose=False, output_type=None, **kwargs): @@ -291,6 +316,10 @@ class RandomForestClassifier(BaseRandomForestModel, state["treelite_handle"] = None state["split_criterion"] = self.split_criterion state["handle"] = self.handle + + if "_cpu_model_class_lock" in state: + del state["_cpu_model_class_lock"] + return state def __setstate__(self, state): @@ -313,6 +342,7 @@ class RandomForestClassifier(BaseRandomForestModel, self.treelite_serialized_model = state["treelite_serialized_model"] self.__dict__.update(state) + self._cpu_model_class_lock = threading.RLock() def __del__(self): self._reset_forest_data() @@ -337,6 +367,9 @@ class RandomForestClassifier(BaseRandomForestModel, self.treelite_serialized_model = None self.n_cols = None + def get_attr_names(self): + return [] + def convert_to_treelite_model(self): """ Converts the cuML RF model to a Treelite model @@ -417,6 +450,7 @@ class RandomForestClassifier(BaseRandomForestModel, @cuml.internals.api_base_return_any(set_output_type=False, set_output_dtype=True, set_n_features_in=False) + @enable_device_interop def fit(self, X, y, convert_dtype=True): """ Perform Random Forest Classification on the input data @@ -428,7 +462,6 @@ class RandomForestClassifier(BaseRandomForestModel, y to be of dtype int32. This will increase memory used for the method. """ - X_m, y_m, max_feature_val = self._dataset_setup_for_fit(X, y, convert_dtype) # Track the labels to see if update is necessary @@ -451,7 +484,7 @@ class RandomForestClassifier(BaseRandomForestModel, if self.random_state is None: seed_val = NULL else: - seed_val = self.random_state + seed_val = check_random_seed(self.random_state) rf_params = set_rf_params( self.max_depth, self.max_leaves, @@ -555,6 +588,7 @@ class RandomForestClassifier(BaseRandomForestModel, @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')], return_values=[('dense', '(n_samples, 1)')]) @cuml.internals.api_base_return_array(get_output_dtype=True) + @enable_device_interop def predict(self, X, predict_model="GPU", threshold=0.5, algo='auto', convert_dtype=True, fil_sparse_format='auto') -> CumlArray: @@ -827,3 +861,39 @@ class RandomForestClassifier(BaseRandomForestModel, if self.dtype == np.float64: return get_rf_json(rf_forest64).decode('utf-8') return get_rf_json(rf_forest).decode('utf-8') + + def cpu_to_gpu(self): + # treelite does an internal isinstance check to detect an sklearn + # RF, which proxymodule interferes with. We work around that + # temporarily here just for treelite internal check and + # restore the __class__ at the end of the method. + if GlobalSettings().accelerator_active: + with self._cpu_model_class_lock: + original_class = self._cpu_model.__class__ + self._cpu_model.__class__ = sys.modules['sklearn.ensemble'].RandomForestClassifier + + try: + super().cpu_to_gpu() + finally: + self._cpu_model.__class__ = original_class + + else: + super().cpu_to_gpu() + + @classmethod + def _hyperparam_translator(cls, **kwargs): + kwargs, gpuaccel = super(RandomForestClassifier, cls)._hyperparam_translator(**kwargs) + + if "max_samples" in kwargs: + if isinstance(kwargs["max_samples"], int): + logger.warn( + f"Integer value of max_samples={kwargs['max_samples']}" + "not supported, changed to 1.0." + ) + kwargs["max_samples"] = 1.0 + + # determinism requires only 1 cuda stream + if "random_state" in kwargs: + kwargs["n_streams"] = 1 + + return kwargs, gpuaccel diff --git a/python/cuml/cuml/ensemble/randomforestregressor.pyx b/python/cuml/cuml/ensemble/randomforestregressor.pyx index 6e3a13d0fb..fac26b0dc0 100644 --- a/python/cuml/cuml/ensemble/randomforestregressor.pyx +++ b/python/cuml/cuml/ensemble/randomforestregressor.pyx @@ -13,9 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # - # distutils: language = c++ + +import sys +import threading + +from cuml.internals.api_decorators import device_interop_preparation +from cuml.internals.api_decorators import enable_device_interop from cuml.internals.safe_imports import ( cpu_only_import, gpu_only_import, @@ -27,13 +32,16 @@ nvtx_annotate = gpu_only_import_from("nvtx", "annotate", alt=null_decorator) rmm = gpu_only_import('rmm') from cuml.internals.array import CumlArray +from cuml.internals.global_settings import GlobalSettings import cuml.internals +from cuml.internals import logger from cuml.internals.mixins import RegressorMixin from cuml.internals.logger cimport level_enum from cuml.common.doc_utils import generate_docstring from cuml.common.doc_utils import insert_into_docstring from cuml.common import input_to_cuml_array +from cuml.internals.utils import check_random_seed from cuml.ensemble.randomforest_common import BaseRandomForestModel from cuml.ensemble.randomforest_common import _obtain_fil_model @@ -250,6 +258,22 @@ class RandomForestRegressor(BaseRandomForestModel, `_. """ + _cpu_estimator_import_path = 'sklearn.ensemble.RandomForestRegressor' + + _hyperparam_interop_translator = { + "criterion": "NotImplemented", + "oob_score": { + True: "NotImplemented", + }, + "max_depth": { + None: 16, + }, + "max_samples": { + None: 1.0, + }, + } + + @device_interop_preparation def __init__(self, *, split_criterion=2, accuracy_metric='r2', @@ -296,6 +320,9 @@ class RandomForestRegressor(BaseRandomForestModel, state["treelite_handle"] = None state["split_criterion"] = self.split_criterion + if "_cpu_model_class_lock" in state: + del state["_cpu_model_class_lock"] + return state def __setstate__(self, state): @@ -317,6 +344,7 @@ class RandomForestRegressor(BaseRandomForestModel, self.treelite_serialized_model = state["treelite_serialized_model"] self.__dict__.update(state) + self._cpu_model_class_lock = threading.RLock() def __del__(self): self._reset_forest_data() @@ -341,6 +369,9 @@ class RandomForestRegressor(BaseRandomForestModel, self.treelite_serialized_model = None self.n_cols = None + def get_attr_names(self): + return [] + def convert_to_treelite_model(self): """ Converts the cuML RF model to a Treelite model @@ -412,6 +443,7 @@ class RandomForestRegressor(BaseRandomForestModel, domain="cuml_python") @generate_docstring() @cuml.internals.api_base_return_any_skipall + @enable_device_interop def fit(self, X, y, convert_dtype=True): """ Perform Random Forest Regression on the input data @@ -438,7 +470,7 @@ class RandomForestRegressor(BaseRandomForestModel, if self.random_state is None: seed_val = NULL else: - seed_val = self.random_state + seed_val = check_random_seed(self.random_state) rf_params = set_rf_params( self.max_depth, self.max_leaves, @@ -534,6 +566,7 @@ class RandomForestRegressor(BaseRandomForestModel, domain="cuml_python") @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')], return_values=[('dense', '(n_samples, 1)')]) + @enable_device_interop def predict(self, X, predict_model="GPU", algo='auto', convert_dtype=True, fil_sparse_format='auto') -> CumlArray: @@ -751,3 +784,39 @@ class RandomForestRegressor(BaseRandomForestModel, if self.dtype == np.float64: return get_rf_json(rf_forest64).decode('utf-8') return get_rf_json(rf_forest).decode('utf-8') + + def cpu_to_gpu(self): + # treelite does an internal isinstance check to detect an sklearn + # RF, which proxymodule interferes with. We work around that + # temporarily here just for treelite internal check and + # restore the __class__ at the end of the method. + if GlobalSettings().accelerator_active: + with self._cpu_model_class_lock: + original_class = self._cpu_model.__class__ + self._cpu_model.__class__ = sys.modules['sklearn.ensemble'].RandomForestRegressor + + try: + super().cpu_to_gpu() + finally: + self._cpu_model.__class__ = original_class + + else: + super().cpu_to_gpu() + + @classmethod + def _hyperparam_translator(cls, **kwargs): + kwargs, gpuaccel = super(RandomForestRegressor, cls)._hyperparam_translator(**kwargs) + + if "max_samples" in kwargs: + if isinstance(kwargs["max_samples"], int): + logger.warn( + f"Integer value of max_samples={kwargs['max_samples']}" + "not supported, changed to 1.0." + ) + kwargs["max_samples"] = 1.0 + + # determinism requires only 1 cuda stream + if "random_state" in kwargs: + kwargs["n_streams"] = 1 + + return kwargs, gpuaccel diff --git a/python/cuml/cuml/fil/fil.pyx b/python/cuml/cuml/fil/fil.pyx index d3764fc758..617cd0c302 100644 --- a/python/cuml/cuml/fil/fil.pyx +++ b/python/cuml/cuml/fil/fil.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2024, NVIDIA CORPORATION. +# Copyright (c) 2019-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -72,6 +72,7 @@ cdef extern from "treelite/c_api.h": const char* filename) except + cdef int TreeliteDeserializeModelFromBytes(const char* bytes_seq, size_t len, TreeliteModelHandle* out) except + + cdef int TreeliteSerializeModelToBytes(TreeliteModelHandle handle, const char** out_bytes, size_t* out_bytes_len) cdef int TreeliteGetHeaderField( TreeliteModelHandle model, const char * name, TreelitePyBufferFrame* out_frame) except + cdef const char* TreeliteGetLastError() @@ -192,6 +193,17 @@ cdef class TreeliteModel(): model.set_handle(handle) return model + def to_treelite_bytes(self) -> bytes: + assert self.handle != NULL + cdef const char* out_bytes + cdef size_t out_bytes_len + cdef int res = TreeliteSerializeModelToBytes(self.handle, &out_bytes, &out_bytes_len) + cdef str err_msg + if res < 0: + err_msg = TreeliteGetLastError().decode("UTF-8") + raise RuntimeError(f"Failed to serialize Treelite model ({err_msg})") + return out_bytes[:out_bytes_len] + @classmethod def from_filename(cls, filename, model_type="xgboost_ubj"): """ diff --git a/python/cuml/cuml/internals/api_decorators.py b/python/cuml/cuml/internals/api_decorators.py index 98e3bb181d..f37ff7df97 100644 --- a/python/cuml/cuml/internals/api_decorators.py +++ b/python/cuml/cuml/internals/api_decorators.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -365,7 +365,25 @@ def processor(self, *args, **kwargs): # Save all kwargs self._full_kwargs = kwargs # Generate list of available cuML hyperparameters - gpu_hyperparams = list(inspect.signature(init_func).parameters.keys()) + + from cuml.ensemble.randomforest_common import BaseRandomForestModel + + # Random Forest models init signature is a combination of the + # parameters in BaseRandomForest and the regressor/classifier classes + # so we need to join their init hyperparameters. + gpu_hyperparams = [] + if isinstance(self, BaseRandomForestModel): + gpu_hyperparams.extend( + list( + inspect.signature( + BaseRandomForestModel.__init__ + ).parameters.keys() + ) + ) + + gpu_hyperparams.extend( + list(inspect.signature(init_func).parameters.keys()) + ) # Filter provided parameters for cuML estimator initialization filtered_kwargs = {} diff --git a/python/cuml/cuml/internals/utils.py b/python/cuml/cuml/internals/utils.py new file mode 100644 index 0000000000..67b2586e26 --- /dev/null +++ b/python/cuml/cuml/internals/utils.py @@ -0,0 +1,39 @@ +# +# Copyright (c) 2024-2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import numbers +import numpy as np + + +def check_random_seed(seed): + """Turn a np.random.RandomState instance into a seed. + Parameters + ---------- + seed : None | int | instance of RandomState + If seed is None, return a random int as seed. + If seed is an int, return it. + If seed is a RandomState instance, derive a seed from it. + Otherwise raise ValueError. + """ + if seed is None: + seed = np.random.RandomState(None) + + if isinstance(seed, numbers.Integral): + return seed + if isinstance(seed, np.random.RandomState): + return seed.randint( + low=0, high=np.iinfo(np.uint32).max, dtype=np.uint32 + ) + raise ValueError("%r cannot be used to create a seed." % seed) diff --git a/python/cuml/cuml/manifold/t_sne.pyx b/python/cuml/cuml/manifold/t_sne.pyx index 08ec39913a..1d8d13ab9e 100644 --- a/python/cuml/cuml/manifold/t_sne.pyx +++ b/python/cuml/cuml/manifold/t_sne.pyx @@ -31,10 +31,10 @@ from cuml.internals.base import UniversalBase from pylibraft.common.handle cimport handle_t from cuml.internals.api_decorators import device_interop_preparation from cuml.internals.api_decorators import enable_device_interop +from cuml.internals.utils import check_random_seed from cuml.internals import logger from cuml.internals cimport logger - from cuml.internals.array import CumlArray from cuml.internals.array_sparse import SparseCumlArray from cuml.common.sparse_utils import is_sparse @@ -596,7 +596,7 @@ class TSNE(UniversalBase, def _build_tsne_params(self, algo): cdef long long seed = -1 if self.random_state is not None: - seed = self.random_state + seed = check_random_seed(self.random_state) cdef TSNEParams* params = new TSNEParams() params.dim = self.n_components diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index 079b270d0a..f047944b6d 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -47,6 +47,7 @@ from cuml.internals.array_sparse import SparseCumlArray from cuml.internals.mem_type import MemoryType from cuml.internals.mixins import CMajorInputTagMixin, SparseInputTagMixin from cuml.common.sparse_utils import is_sparse +from cuml.internals.utils import check_random_seed from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.internals.api_decorators import device_interop_preparation @@ -334,6 +335,22 @@ class UMAP(UniversalBase, _cpu_estimator_import_path = 'umap.UMAP' embedding_ = CumlArrayDescriptor(order='C') + _hyperparam_interop_translator = { + "metric": { + "sokalsneath": "NotImplemented", + "rogerstanimoto": "NotImplemented", + "sokalmichener": "NotImplemented", + "yule": "NotImplemented", + "ll_dirichlet": "NotImplemented", + "russellrao": "NotImplemented", + "kulsinski": "NotImplemented", + "dice": "NotImplemented", + "wminkowski": "NotImplemented", + "mahalanobis": "NotImplemented", + "haversine": "NotImplemented", + } + } + @device_interop_preparation def __init__(self, *, n_neighbors=15, @@ -401,22 +418,7 @@ class UMAP(UniversalBase, self.deterministic = random_state is not None - # Check to see if we are already a random_state (type==np.uint64). - # Reuse this if already passed (can happen from get_params() of another - # instance) - if isinstance(random_state, np.uint64): - self.random_state = random_state - else: - # Otherwise create a RandomState instance to generate a new - # np.uint64 - if isinstance(random_state, np.random.RandomState): - rs = random_state - else: - rs = np.random.RandomState(random_state) - - self.random_state = rs.randint(low=0, - high=np.iinfo(np.uint32).max, - dtype=np.uint32) + self.random_state = random_state if target_metric == "euclidean" or target_metric == "categorical": self.target_metric = target_metric @@ -467,38 +469,37 @@ class UMAP(UniversalBase, if self.min_dist > self.spread: raise ValueError("min_dist should be <= spread") - @staticmethod - def _build_umap_params(cls, sparse): + def _build_umap_params(self, sparse): IF GPUBUILD == 1: cdef UMAPParams* umap_params = new UMAPParams() - umap_params.n_neighbors = cls.n_neighbors - umap_params.n_components = cls.n_components - umap_params.n_epochs = cls.n_epochs if cls.n_epochs else 0 - umap_params.learning_rate = cls.learning_rate - umap_params.min_dist = cls.min_dist - umap_params.spread = cls.spread - umap_params.set_op_mix_ratio = cls.set_op_mix_ratio - umap_params.local_connectivity = cls.local_connectivity - umap_params.repulsion_strength = cls.repulsion_strength - umap_params.negative_sample_rate = cls.negative_sample_rate - umap_params.transform_queue_size = cls.transform_queue_size - umap_params.verbosity = cls.verbose - umap_params.a = cls.a - umap_params.b = cls.b - if cls.init == "spectral": + umap_params.n_neighbors = self.n_neighbors + umap_params.n_components = self.n_components + umap_params.n_epochs = self.n_epochs if self.n_epochs else 0 + umap_params.learning_rate = self.learning_rate + umap_params.min_dist = self.min_dist + umap_params.spread = self.spread + umap_params.set_op_mix_ratio = self.set_op_mix_ratio + umap_params.local_connectivity = self.local_connectivity + umap_params.repulsion_strength = self.repulsion_strength + umap_params.negative_sample_rate = self.negative_sample_rate + umap_params.transform_queue_size = self.transform_queue_size + umap_params.verbosity = self.verbose + umap_params.a = self.a + umap_params.b = self.b + if self.init == "spectral": umap_params.init = 1 else: # self.init == "random" umap_params.init = 0 - umap_params.target_n_neighbors = cls.target_n_neighbors - if cls.target_metric == "euclidean": + umap_params.target_n_neighbors = self.target_n_neighbors + if self.target_metric == "euclidean": umap_params.target_metric = MetricType.EUCLIDEAN else: # self.target_metric == "categorical" umap_params.target_metric = MetricType.CATEGORICAL - if cls.build_algo == "brute_force_knn": + if self.build_algo == "brute_force_knn": umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN else: # self.init == "nn_descent" umap_params.build_algo = graph_build_algo.NN_DESCENT - if cls.build_kwds is None: + if self.build_kwds is None: umap_params.nn_descent_params.graph_degree = 64 umap_params.nn_descent_params.intermediate_graph_degree = 128 umap_params.nn_descent_params.max_iterations = 20 @@ -506,38 +507,38 @@ class UMAP(UniversalBase, umap_params.nn_descent_params.return_distances = True umap_params.nn_descent_params.n_clusters = 1 else: - umap_params.nn_descent_params.graph_degree = cls.build_kwds.get("nnd_graph_degree", 64) - umap_params.nn_descent_params.intermediate_graph_degree = cls.build_kwds.get("nnd_intermediate_graph_degree", 128) - umap_params.nn_descent_params.max_iterations = cls.build_kwds.get("nnd_max_iterations", 20) - umap_params.nn_descent_params.termination_threshold = cls.build_kwds.get("nnd_termination_threshold", 0.0001) - umap_params.nn_descent_params.return_distances = cls.build_kwds.get("nnd_return_distances", True) - if cls.build_kwds.get("nnd_n_clusters", 1) < 1: + umap_params.nn_descent_params.graph_degree = self.build_kwds.get("nnd_graph_degree", 64) + umap_params.nn_descent_params.intermediate_graph_degree = self.build_kwds.get("nnd_intermediate_graph_degree", 128) + umap_params.nn_descent_params.max_iterations = self.build_kwds.get("nnd_max_iterations", 20) + umap_params.nn_descent_params.termination_threshold = self.build_kwds.get("nnd_termination_threshold", 0.0001) + umap_params.nn_descent_params.return_distances = self.build_kwds.get("nnd_return_distances", True) + if self.build_kwds.get("nnd_n_clusters", 1) < 1: logger.info("Negative number of nnd_n_clusters not allowed. Changing nnd_n_clusters to 1") - umap_params.nn_descent_params.n_clusters = cls.build_kwds.get("nnd_n_clusters", 1) + umap_params.nn_descent_params.n_clusters = self.build_kwds.get("nnd_n_clusters", 1) - umap_params.target_weight = cls.target_weight - umap_params.random_state = cls.random_state - umap_params.deterministic = cls.deterministic + umap_params.target_weight = self.target_weight + umap_params.random_state = check_random_seed(self.random_state) + umap_params.deterministic = self.deterministic try: - umap_params.metric = metric_parsing[cls.metric.lower()] + umap_params.metric = metric_parsing[self.metric.lower()] if sparse: if umap_params.metric not in SPARSE_SUPPORTED_METRICS: - raise NotImplementedError(f"Metric '{cls.metric}' not supported for sparse inputs.") + raise NotImplementedError(f"Metric '{self.metric}' not supported for sparse inputs.") elif umap_params.metric not in DENSE_SUPPORTED_METRICS: - raise NotImplementedError(f"Metric '{cls.metric}' not supported for dense inputs.") + raise NotImplementedError(f"Metric '{self.metric}' not supported for dense inputs.") except KeyError: - raise ValueError(f"Invalid value for metric: {cls.metric}") + raise ValueError(f"Invalid value for metric: {self.metric}") - if cls.metric_kwds is None: + if self.metric_kwds is None: umap_params.p = 2.0 else: - umap_params.p = cls.metric_kwds.get('p') + umap_params.p = self.metric_kwds.get('p') cdef uintptr_t callback_ptr = 0 - if cls.callback: - callback_ptr = cls.callback.get_native_callback() + if self.callback: + callback_ptr = self.callback.get_native_callback() umap_params.callback = callback_ptr return umap_params @@ -672,7 +673,7 @@ class UMAP(UniversalBase, self.handle.getHandle() fss_graph = GraphHolder.new_graph(handle_.get_stream()) cdef UMAPParams* umap_params = \ - UMAP._build_umap_params(self, + self._build_umap_params( self.sparse_fit) if self.sparse_fit: fit_sparse(handle_[0], @@ -831,7 +832,7 @@ class UMAP(UniversalBase, IF GPUBUILD == 1: cdef UMAPParams* umap_params = \ - UMAP._build_umap_params(self, + self._build_umap_params( self.sparse_fit) cdef handle_t * handle_ = \ self.handle.getHandle() diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_rf_classifier.py b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_rf_classifier.py new file mode 100644 index 0000000000..188aceacf4 --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_rf_classifier.py @@ -0,0 +1,210 @@ +# +# Copyright (c) 2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import pytest +import numpy as np +from sklearn.datasets import make_classification +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import accuracy_score + + +@pytest.fixture(scope="module") +def classification_data(): + # Create a synthetic classification dataset. + X, y = make_classification( + n_samples=300, + n_features=20, + n_informative=10, + n_redundant=5, + random_state=42, + ) + return X, y + + +@pytest.mark.parametrize("n_estimators", [10, 50, 100]) +def test_rf_n_estimators(classification_data, n_estimators): + X, y = classification_data + clf = RandomForestClassifier(n_estimators=n_estimators, random_state=42) + clf.fit(X, y) + _ = accuracy_score(y, clf.predict(X)) + + +@pytest.mark.parametrize("criterion", ["gini", "entropy"]) +def test_rf_criterion(classification_data, criterion): + X, y = classification_data + clf = RandomForestClassifier( + criterion=criterion, n_estimators=50, random_state=42 + ) + clf.fit(X, y) + _ = accuracy_score(y, clf.predict(X)) + + +@pytest.mark.parametrize("max_depth", [None, 5, 10]) +def test_rf_max_depth(classification_data, max_depth): + X, y = classification_data + clf = RandomForestClassifier( + max_depth=max_depth, n_estimators=50, random_state=42 + ) + clf.fit(X, y) + _ = accuracy_score(y, clf.predict(X)) + + +@pytest.mark.parametrize("min_samples_split", [2, 5, 10]) +def test_rf_min_samples_split(classification_data, min_samples_split): + X, y = classification_data + clf = RandomForestClassifier( + min_samples_split=min_samples_split, n_estimators=50, random_state=42 + ) + clf.fit(X, y) + _ = accuracy_score(y, clf.predict(X)) + + +@pytest.mark.parametrize("min_samples_leaf", [1, 2, 4]) +def test_rf_min_samples_leaf(classification_data, min_samples_leaf): + X, y = classification_data + clf = RandomForestClassifier( + min_samples_leaf=min_samples_leaf, n_estimators=50, random_state=42 + ) + clf.fit(X, y) + _ = accuracy_score(y, clf.predict(X)) + + +@pytest.mark.parametrize("min_weight_fraction_leaf", [0.0, 0.1]) +def test_rf_min_weight_fraction_leaf( + classification_data, min_weight_fraction_leaf +): + X, y = classification_data + clf = RandomForestClassifier( + min_weight_fraction_leaf=min_weight_fraction_leaf, + n_estimators=50, + random_state=42, + ) + clf.fit(X, y) + _ = accuracy_score(y, clf.predict(X)) + + +@pytest.mark.parametrize("max_features", ["sqrt", "log2", 0.5, 5]) +def test_rf_max_features(classification_data, max_features): + X, y = classification_data + clf = RandomForestClassifier( + max_features=max_features, n_estimators=50, random_state=42 + ) + clf.fit(X, y) + _ = accuracy_score(y, clf.predict(X)) + + +@pytest.mark.parametrize("max_leaf_nodes", [None, 10, 20]) +def test_rf_max_leaf_nodes(classification_data, max_leaf_nodes): + X, y = classification_data + clf = RandomForestClassifier( + max_leaf_nodes=max_leaf_nodes, n_estimators=50, random_state=42 + ) + clf.fit(X, y) + _ = accuracy_score(y, clf.predict(X)) + + +@pytest.mark.parametrize("min_impurity_decrease", [0.0, 0.1]) +def test_rf_min_impurity_decrease(classification_data, min_impurity_decrease): + X, y = classification_data + clf = RandomForestClassifier( + min_impurity_decrease=min_impurity_decrease, + n_estimators=50, + random_state=42, + ) + clf.fit(X, y) + _ = accuracy_score(y, clf.predict(X)) + + +@pytest.mark.parametrize("bootstrap", [True, False]) +def test_rf_bootstrap(classification_data, bootstrap): + X, y = classification_data + clf = RandomForestClassifier( + bootstrap=bootstrap, n_estimators=50, random_state=42 + ) + clf.fit(X, y) + _ = accuracy_score(y, clf.predict(X)) + + +@pytest.mark.parametrize("n_jobs", [1, -1]) +def test_rf_n_jobs(classification_data, n_jobs): + X, y = classification_data + clf = RandomForestClassifier( + n_jobs=n_jobs, n_estimators=50, random_state=42 + ) + clf.fit(X, y) + _ = accuracy_score(y, clf.predict(X)) + + +@pytest.mark.parametrize("verbose", [0, 1]) +def test_rf_verbose(classification_data, verbose): + X, y = classification_data + clf = RandomForestClassifier( + verbose=verbose, n_estimators=50, random_state=42 + ) + clf.fit(X, y) + _ = accuracy_score(y, clf.predict(X)) + + +@pytest.mark.parametrize("warm_start", [False, True]) +def test_rf_warm_start(classification_data, warm_start): + X, y = classification_data + clf = RandomForestClassifier( + warm_start=warm_start, n_estimators=50, random_state=42 + ) + clf.fit(X, y) + _ = accuracy_score(y, clf.predict(X)) + + +@pytest.mark.parametrize("class_weight", [None, "balanced", {0: 1, 1: 2}]) +def test_rf_class_weight(classification_data, class_weight): + X, y = classification_data + clf = RandomForestClassifier( + class_weight=class_weight, n_estimators=50, random_state=42 + ) + clf.fit(X, y) + _ = accuracy_score(y, clf.predict(X)) + + +@pytest.mark.parametrize("ccp_alpha", [0.0, 0.1]) +def test_rf_ccp_alpha(classification_data, ccp_alpha): + X, y = classification_data + clf = RandomForestClassifier( + ccp_alpha=ccp_alpha, n_estimators=50, random_state=42 + ) + clf.fit(X, y) + _ = accuracy_score(y, clf.predict(X)) + + +@pytest.mark.parametrize("max_samples", [None, 0.8, 50]) +def test_rf_max_samples(classification_data, max_samples): + X, y = classification_data + clf = RandomForestClassifier( + max_samples=max_samples, + bootstrap=True, + n_estimators=50, + random_state=42, + ) + clf.fit(X, y) + _ = accuracy_score(y, clf.predict(X)) + + +def test_rf_random_state(classification_data): + X, y = classification_data + clf1 = RandomForestClassifier(n_estimators=50, random_state=42).fit(X, y) + clf2 = RandomForestClassifier(n_estimators=50, random_state=42).fit(X, y) + # Predictions should be identical with the same random_state. + assert np.array_equal(clf1.predict(X), clf2.predict(X)) diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_rf_regressor.py b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_rf_regressor.py new file mode 100644 index 0000000000..60d61e7a79 --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_rf_regressor.py @@ -0,0 +1,191 @@ +# +# Copyright (c) 2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +import numpy as np +from sklearn.datasets import make_regression +from sklearn.ensemble import RandomForestRegressor +from sklearn.metrics import r2_score + + +@pytest.fixture(scope="module") +def regression_data(): + # Create a synthetic regression dataset. + X, y = make_regression( + n_samples=300, + n_features=20, + n_informative=10, + noise=0.1, + random_state=42, + ) + return X, y + + +@pytest.mark.parametrize("n_estimators", [10, 50, 100]) +def test_rf_n_estimators_reg(regression_data, n_estimators): + X, y = regression_data + reg = RandomForestRegressor(n_estimators=n_estimators, random_state=42) + reg.fit(X, y) + _ = r2_score(y, reg.predict(X)) + + +@pytest.mark.parametrize("criterion", ["squared_error", "absolute_error"]) +def test_rf_criterion_reg(regression_data, criterion): + X, y = regression_data + reg = RandomForestRegressor( + criterion=criterion, n_estimators=50, random_state=42 + ) + reg.fit(X, y) + _ = r2_score(y, reg.predict(X)) + + +@pytest.mark.parametrize("max_depth", [None, 5, 10]) +def test_rf_max_depth_reg(regression_data, max_depth): + X, y = regression_data + reg = RandomForestRegressor( + max_depth=max_depth, n_estimators=50, random_state=42 + ) + reg.fit(X, y) + _ = r2_score(y, reg.predict(X)) + + +@pytest.mark.parametrize("min_samples_split", [2, 5, 10]) +def test_rf_min_samples_split_reg(regression_data, min_samples_split): + X, y = regression_data + reg = RandomForestRegressor( + min_samples_split=min_samples_split, n_estimators=50, random_state=42 + ) + reg.fit(X, y) + _ = r2_score(y, reg.predict(X)) + + +@pytest.mark.parametrize("min_samples_leaf", [1, 2, 4]) +def test_rf_min_samples_leaf_reg(regression_data, min_samples_leaf): + X, y = regression_data + reg = RandomForestRegressor( + min_samples_leaf=min_samples_leaf, n_estimators=50, random_state=42 + ) + reg.fit(X, y) + _ = r2_score(y, reg.predict(X)) + + +@pytest.mark.parametrize("min_weight_fraction_leaf", [0.0, 0.1]) +def test_rf_min_weight_fraction_leaf_reg( + regression_data, min_weight_fraction_leaf +): + X, y = regression_data + reg = RandomForestRegressor( + min_weight_fraction_leaf=min_weight_fraction_leaf, + n_estimators=50, + random_state=42, + ) + reg.fit(X, y) + _ = r2_score(y, reg.predict(X)) + + +@pytest.mark.parametrize("max_features", ["sqrt", "log2", 0.5, 5]) +def test_rf_max_features_reg(regression_data, max_features): + X, y = regression_data + reg = RandomForestRegressor( + max_features=max_features, n_estimators=50, random_state=42 + ) + reg.fit(X, y) + _ = r2_score(y, reg.predict(X)) + + +@pytest.mark.parametrize("max_leaf_nodes", [None, 10, 20]) +def test_rf_max_leaf_nodes_reg(regression_data, max_leaf_nodes): + X, y = regression_data + reg = RandomForestRegressor( + max_leaf_nodes=max_leaf_nodes, n_estimators=50, random_state=42 + ) + reg.fit(X, y) + _ = r2_score(y, reg.predict(X)) + + +@pytest.mark.parametrize("min_impurity_decrease", [0.0, 0.1]) +def test_rf_min_impurity_decrease_reg(regression_data, min_impurity_decrease): + X, y = regression_data + reg = RandomForestRegressor( + min_impurity_decrease=min_impurity_decrease, + n_estimators=50, + random_state=42, + ) + reg.fit(X, y) + _ = r2_score(y, reg.predict(X)) + + +@pytest.mark.parametrize("bootstrap", [True, False]) +def test_rf_bootstrap_reg(regression_data, bootstrap): + X, y = regression_data + reg = RandomForestRegressor( + bootstrap=bootstrap, n_estimators=50, random_state=42 + ) + reg.fit(X, y) + _ = r2_score(y, reg.predict(X)) + + +@pytest.mark.parametrize("n_jobs", [1, -1]) +def test_rf_n_jobs_reg(regression_data, n_jobs): + X, y = regression_data + reg = RandomForestRegressor( + n_jobs=n_jobs, n_estimators=50, random_state=42 + ) + reg.fit(X, y) + _ = r2_score(y, reg.predict(X)) + + +@pytest.mark.parametrize("verbose", [0, 1]) +def test_rf_verbose_reg(regression_data, verbose): + X, y = regression_data + reg = RandomForestRegressor( + verbose=verbose, n_estimators=50, random_state=42 + ) + reg.fit(X, y) + _ = r2_score(y, reg.predict(X)) + + +@pytest.mark.parametrize("warm_start", [False, True]) +def test_rf_warm_start_reg(regression_data, warm_start): + X, y = regression_data + reg = RandomForestRegressor( + warm_start=warm_start, n_estimators=50, random_state=42 + ) + reg.fit(X, y) + _ = r2_score(y, reg.predict(X)) + + +@pytest.mark.parametrize("ccp_alpha", [0.0, 0.1]) +def test_rf_ccp_alpha_reg(regression_data, ccp_alpha): + X, y = regression_data + reg = RandomForestRegressor( + ccp_alpha=ccp_alpha, n_estimators=50, random_state=42 + ) + reg.fit(X, y) + _ = r2_score(y, reg.predict(X)) + + +@pytest.mark.parametrize("max_samples", [None, 0.8, 50]) +def test_rf_max_samples_reg(regression_data, max_samples): + X, y = regression_data + reg = RandomForestRegressor( + max_samples=max_samples, + bootstrap=True, + n_estimators=50, + random_state=42, + ) + reg.fit(X, y) + _ = r2_score(y, reg.predict(X)) diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_umap.py b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_umap.py index 543d545caf..67ace7c3fd 100644 --- a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_umap.py +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_umap.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the “License”); # you may not use this file except in compliance with the License. @@ -46,10 +46,33 @@ def test_umap_min_dist(manifold_data, min_dist): @pytest.mark.parametrize( - "metric", ["euclidean", "manhattan", "chebyshev", "cosine"] + "metric", + [ + "euclidean", + "manhattan", + "chebyshev", + "cosine", + # These metrics are currently not supported in cuml, + # we test them here to make sure no exception is raised + "sokalsneath", + "rogerstanimoto", + "sokalmichener", + "yule", + "ll_dirichlet", + "russellrao", + "kulsinski", + "dice", + "wminkowski", + "mahalanobis", + "haversine", + ], ) def test_umap_metric(manifold_data, metric): X = manifold_data + # haversine only works for 2D data + if metric == "haversine": + X = X[:, :2] + umap = UMAP(metric=metric, random_state=42) X_embedded = umap.fit_transform(X) trust = trustworthiness(X, X_embedded, n_neighbors=5) diff --git a/python/cuml/cuml/tests/test_common.py b/python/cuml/cuml/tests/test_common.py new file mode 100644 index 0000000000..0dc6ccd575 --- /dev/null +++ b/python/cuml/cuml/tests/test_common.py @@ -0,0 +1,42 @@ +# +# Copyright (c) 2024-2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest + +import numpy as np +import cuml +from cuml.datasets import make_blobs + + +@pytest.mark.parametrize( + "Estimator", + [ + cuml.KMeans, + cuml.RandomForestRegressor, + cuml.RandomForestClassifier, + cuml.TSNE, + cuml.UMAP, + ], +) +def test_random_state_argument(Estimator): + X, y = make_blobs(random_state=0) + # Check that both integer and np.random.RandomState are accepted + for seed in (42, np.random.RandomState(42)): + est = Estimator(random_state=seed) + + if est.__class__.__name__ != "TSNE": + est.fit(X, y) + else: + est.fit(X) diff --git a/python/cuml/cuml/tests/test_device_selection.py b/python/cuml/cuml/tests/test_device_selection.py index 31c0f9aed6..9d37bf8553 100644 --- a/python/cuml/cuml/tests/test_device_selection.py +++ b/python/cuml/cuml/tests/test_device_selection.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# Copyright (c) 2022-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -37,6 +37,7 @@ from cuml.decomposition import PCA, TruncatedSVD from cuml.cluster import KMeans from cuml.cluster import DBSCAN +from cuml.ensemble import RandomForestClassifier, RandomForestRegressor from cuml.common.device_selection import DeviceType, using_device_type from cuml.testing.utils import assert_dbscan_equal from hdbscan import HDBSCAN as refHDBSCAN @@ -50,8 +51,11 @@ from sklearn.decomposition import TruncatedSVD as skTruncatedSVD from sklearn.cluster import KMeans as skKMeans from sklearn.cluster import DBSCAN as skDBSCAN +from sklearn.ensemble import RandomForestClassifier as skRFC +from sklearn.ensemble import RandomForestRegressor as skRFR from sklearn.datasets import make_regression, make_blobs from sklearn.manifold import TSNE as refTSNE +from sklearn.metrics import accuracy_score, r2_score from pytest_cases import fixture_union, fixture from importlib import import_module import inspect @@ -135,11 +139,12 @@ def make_reg_dataset(): n_samples=2000, n_features=20, n_informative=18, random_state=0 ) X_train, X_test = X[:1800], X[1800:] - y_train, _ = y[:1800], y[1800:] + y_train, y_test = y[:1800], y[1800:] return ( X_train.astype(np.float32), y_train.astype(np.float32), X_test.astype(np.float32), + y_test.astype(np.float32), ) @@ -152,16 +157,17 @@ def make_blob_dataset(): cluster_std=1.0, ) X_train, X_test = X[:1800], X[1800:] - y_train, _ = y[:1800], y[1800:] + y_train, y_test = y[:1800], y[1800:] return ( X_train.astype(np.float32), y_train.astype(np.float32), X_test.astype(np.float32), + y_test.astype(np.float32), ) -X_train_reg, y_train_reg, X_test_reg = make_reg_dataset() -X_train_blob, y_train_blob, X_test_blob = make_blob_dataset() +X_train_reg, y_train_reg, X_test_reg, y_test_reg = make_reg_dataset() +X_train_blob, y_train_blob, X_test_blob, y_test_blob = make_blob_dataset() def check_trustworthiness(cuml_embedding, test_data): @@ -977,11 +983,11 @@ def test_hdbscan_methods(train_device, infer_device): @pytest.mark.parametrize("infer_device", ["cpu", "gpu"]) def test_kmeans_methods(train_device, infer_device): n_clusters = 20 - ref_model = skKMeans(n_clusters=n_clusters) + ref_model = skKMeans(n_clusters=n_clusters, random_state=42) ref_model.fit(X_train_blob) ref_output = ref_model.predict(X_test_blob) - model = KMeans(n_clusters=n_clusters) + model = KMeans(n_clusters=n_clusters, random_state=42) with using_device_type(train_device): model.fit(X_train_blob) with using_device_type(infer_device): @@ -1011,3 +1017,67 @@ def test_dbscan_methods(train_device, infer_device): assert_dbscan_equal( ref_output, output, X_train_blob, model.core_sample_indices_, eps ) + + +@pytest.mark.parametrize("train_device", ["cpu", "gpu"]) +@pytest.mark.parametrize("infer_device", ["cpu", "gpu"]) +def test_random_forest_regressor(train_device, infer_device): + ref_model = skRFR( + n_estimators=40, + max_depth=16, + min_samples_split=2, + max_features=1.0, + random_state=10, + ) + model = RandomForestRegressor( + max_features=1.0, + max_depth=16, + n_bins=64, + n_estimators=40, + n_streams=1, + random_state=10, + ) + ref_model.fit(X_train_reg, y_train_reg) + ref_output = ref_model.predict(X_test_reg) + + with using_device_type(train_device): + model.fit(X_train_reg, y_train_reg) + with using_device_type(infer_device): + output = model.predict(X_test_reg) + + cuml_acc = r2_score(y_test_reg, output) + sk_acc = r2_score(y_test_reg, ref_output) + + assert np.abs(cuml_acc - sk_acc) <= 0.05 + + +@pytest.mark.parametrize("train_device", ["cpu", "gpu"]) +@pytest.mark.parametrize("infer_device", ["cpu", "gpu"]) +def test_random_forest_classifier(train_device, infer_device): + ref_model = skRFC( + n_estimators=40, + max_depth=16, + min_samples_split=2, + max_features=1.0, + random_state=10, + ) + model = RandomForestClassifier( + max_features=1.0, + max_depth=16, + n_bins=64, + n_estimators=40, + n_streams=1, + random_state=10, + ) + ref_model.fit(X_train_blob, y_train_blob) + ref_output = ref_model.predict(X_test_blob) + + with using_device_type(train_device): + model.fit(X_train_blob, y_train_blob) + with using_device_type(infer_device): + output = model.predict(X_test_blob) + + cuml_acc = accuracy_score(y_test_blob, output) + ref_acc = accuracy_score(y_test_blob, ref_output) + + assert cuml_acc == ref_acc