diff --git a/python/cuml/cuml/cluster/kmeans.pyx b/python/cuml/cuml/cluster/kmeans.pyx index 6be09f6912..33a84d878d 100644 --- a/python/cuml/cuml/cluster/kmeans.pyx +++ b/python/cuml/cuml/cluster/kmeans.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2024, NVIDIA CORPORATION. +# Copyright (c) 2019-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ np = cpu_only_import('numpy') from cuml.internals.safe_imports import gpu_only_import rmm = gpu_only_import('rmm') from cuml.internals.safe_imports import safe_import_from, return_false +from cuml.internals.utils import check_random_seed import typing IF GPUBUILD == 1: @@ -209,8 +210,11 @@ class KMeans(UniversalBase, params.init = self._params_init params.max_iter = self.max_iter params.tol = self.tol + # After transferring from one device to another `_seed` might not be set + # so we need to pass a dummy value here. Its value does not matter as the + # seed is only used during fitting + params.rng_state.seed = getattr(self, "_seed", 0) params.verbosity = (self.verbose) - params.rng_state.seed = self.random_state params.metric = DistanceType.L2Expanded # distance metric as squared L2: @todo - support other metrics # noqa: E501 params.batch_samples = self.max_samples_per_batch params.oversampling_factor = self.oversampling_factor @@ -307,6 +311,7 @@ class KMeans(UniversalBase, else None), check_dtype=check_dtype) + self._seed = check_random_seed(self.random_state) self.feature_names_in_ = _X_m.index IF GPUBUILD == 1: diff --git a/python/cuml/cuml/cluster/kmeans_mg.pyx b/python/cuml/cuml/cluster/kmeans_mg.pyx index cf0a2967c4..1294467755 100644 --- a/python/cuml/cuml/cluster/kmeans_mg.pyx +++ b/python/cuml/cuml/cluster/kmeans_mg.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2024, NVIDIA CORPORATION. +# Copyright (c) 2019-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ from cuml.common import input_to_cuml_array from cuml.cluster import KMeans from cuml.cluster.kmeans_utils cimport params as KMeansParams +from cuml.internals.utils import check_random_seed cdef extern from "cuml/cluster/kmeans_mg.hpp" \ @@ -129,6 +130,8 @@ class KMeansMG(KMeans): cdef uintptr_t sample_weight_ptr = sample_weight_m.ptr + self._seed = check_random_seed(self.random_state) + if (self.init in ['scalable-k-means++', 'k-means||', 'random']): self.cluster_centers_ = CumlArray.zeros(shape=(self.n_clusters, self.n_cols), diff --git a/python/cuml/cuml/decomposition/pca.pyx b/python/cuml/cuml/decomposition/pca.pyx index db2f0f62c8..9a055e6505 100644 --- a/python/cuml/cuml/decomposition/pca.pyx +++ b/python/cuml/cuml/decomposition/pca.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2024, NVIDIA CORPORATION. +# Copyright (c) 2019-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -209,9 +209,6 @@ class PCA(UniversalBase, ``n_components = min(n_samples, n_features)`` - random_state : int / None (default = None) - If you want results to be the same when you restart Python, select a - state. svd_solver : 'full' or 'jacobi' or 'auto' (default = 'full') Full uses a eigendecomposition of the covariance matrix then discards components. 
@@ -292,7 +289,7 @@ class PCA(UniversalBase, @device_interop_preparation def __init__(self, *, copy=True, handle=None, iterated_power=15, - n_components=None, random_state=None, svd_solver='auto', + n_components=None, svd_solver='auto', tol=1e-7, verbose=False, whiten=False, output_type=None): # parameters @@ -302,7 +299,6 @@ class PCA(UniversalBase, self.copy = copy self.iterated_power = iterated_power self.n_components = n_components - self.random_state = random_state self.svd_solver = svd_solver self.tol = tol self.whiten = whiten @@ -739,7 +735,7 @@ class PCA(UniversalBase, def _get_param_names(cls): return super()._get_param_names() + \ ["copy", "iterated_power", "n_components", "svd_solver", "tol", - "whiten", "random_state"] + "whiten"] def _check_is_fitted(self, attr): if not hasattr(self, attr) or (getattr(self, attr) is None): diff --git a/python/cuml/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/cuml/ensemble/randomforestclassifier.pyx index 5198d60b28..aa0bea6f0c 100644 --- a/python/cuml/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/cuml/ensemble/randomforestclassifier.pyx @@ -33,6 +33,7 @@ import cuml.internals from cuml.common.doc_utils import generate_docstring from cuml.common.doc_utils import insert_into_docstring from cuml.common import input_to_cuml_array +from cuml.internals.utils import check_random_seed from cuml.internals.logger cimport level_enum from cuml.ensemble.randomforest_common import BaseRandomForestModel @@ -451,7 +452,7 @@ class RandomForestClassifier(BaseRandomForestModel, if self.random_state is None: seed_val = NULL else: - seed_val = self.random_state + seed_val = check_random_seed(self.random_state) rf_params = set_rf_params( self.max_depth, self.max_leaves, diff --git a/python/cuml/cuml/ensemble/randomforestregressor.pyx b/python/cuml/cuml/ensemble/randomforestregressor.pyx index 6e3a13d0fb..c30849d566 100644 --- a/python/cuml/cuml/ensemble/randomforestregressor.pyx +++ b/python/cuml/cuml/ensemble/randomforestregressor.pyx @@ -34,6 +34,7 @@ from cuml.internals.logger cimport level_enum from cuml.common.doc_utils import generate_docstring from cuml.common.doc_utils import insert_into_docstring from cuml.common import input_to_cuml_array +from cuml.internals.utils import check_random_seed from cuml.ensemble.randomforest_common import BaseRandomForestModel from cuml.ensemble.randomforest_common import _obtain_fil_model @@ -438,7 +439,7 @@ class RandomForestRegressor(BaseRandomForestModel, if self.random_state is None: seed_val = NULL else: - seed_val = self.random_state + seed_val = check_random_seed(self.random_state) rf_params = set_rf_params( self.max_depth, self.max_leaves, diff --git a/python/cuml/cuml/internals/utils.py b/python/cuml/cuml/internals/utils.py new file mode 100644 index 0000000000..67b2586e26 --- /dev/null +++ b/python/cuml/cuml/internals/utils.py @@ -0,0 +1,39 @@ +# +# Copyright (c) 2024-2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import numbers +import numpy as np + + +def check_random_seed(seed): + """Turn a np.random.RandomState instance into a seed. + Parameters + ---------- + seed : None | int | instance of RandomState + If seed is None, return a random int as seed. + If seed is an int, return it. + If seed is a RandomState instance, derive a seed from it. + Otherwise raise ValueError. + """ + if seed is None: + seed = np.random.RandomState(None) + + if isinstance(seed, numbers.Integral): + return seed + if isinstance(seed, np.random.RandomState): + return seed.randint( + low=0, high=np.iinfo(np.uint32).max, dtype=np.uint32 + ) + raise ValueError("%r cannot be used to create a seed." % seed) diff --git a/python/cuml/cuml/manifold/t_sne.pyx b/python/cuml/cuml/manifold/t_sne.pyx index 7ff8702a2c..d691b8e923 100644 --- a/python/cuml/cuml/manifold/t_sne.pyx +++ b/python/cuml/cuml/manifold/t_sne.pyx @@ -31,10 +31,10 @@ from cuml.internals.base import UniversalBase from pylibraft.common.handle cimport handle_t from cuml.internals.api_decorators import device_interop_preparation from cuml.internals.api_decorators import enable_device_interop +from cuml.internals.utils import check_random_seed from cuml.internals import logger from cuml.internals cimport logger - from cuml.internals.array import CumlArray from cuml.internals.array_sparse import SparseCumlArray from cuml.common.sparse_utils import is_sparse @@ -596,7 +596,7 @@ class TSNE(UniversalBase, def _build_tsne_params(self, algo): cdef long long seed = -1 if self.random_state is not None: - seed = self.random_state + seed = check_random_seed(self.random_state) cdef TSNEParams* params = new TSNEParams() params.dim = self.n_components diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index 079b270d0a..0f5e2c95b2 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -47,6 +47,7 @@ from cuml.internals.array_sparse import SparseCumlArray from cuml.internals.mem_type import MemoryType from cuml.internals.mixins import CMajorInputTagMixin, SparseInputTagMixin from cuml.common.sparse_utils import is_sparse +from cuml.internals.utils import check_random_seed from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.internals.api_decorators import device_interop_preparation @@ -401,22 +402,7 @@ class UMAP(UniversalBase, self.deterministic = random_state is not None - # Check to see if we are already a random_state (type==np.uint64). 
- # Reuse this if already passed (can happen from get_params() of another - # instance) - if isinstance(random_state, np.uint64): - self.random_state = random_state - else: - # Otherwise create a RandomState instance to generate a new - # np.uint64 - if isinstance(random_state, np.random.RandomState): - rs = random_state - else: - rs = np.random.RandomState(random_state) - - self.random_state = rs.randint(low=0, - high=np.iinfo(np.uint32).max, - dtype=np.uint32) + self.random_state = random_state if target_metric == "euclidean" or target_metric == "categorical": self.target_metric = target_metric @@ -467,38 +453,37 @@ class UMAP(UniversalBase, if self.min_dist > self.spread: raise ValueError("min_dist should be <= spread") - @staticmethod - def _build_umap_params(cls, sparse): + def _build_umap_params(self, sparse): IF GPUBUILD == 1: cdef UMAPParams* umap_params = new UMAPParams() - umap_params.n_neighbors = cls.n_neighbors - umap_params.n_components = cls.n_components - umap_params.n_epochs = cls.n_epochs if cls.n_epochs else 0 - umap_params.learning_rate = cls.learning_rate - umap_params.min_dist = cls.min_dist - umap_params.spread = cls.spread - umap_params.set_op_mix_ratio = cls.set_op_mix_ratio - umap_params.local_connectivity = cls.local_connectivity - umap_params.repulsion_strength = cls.repulsion_strength - umap_params.negative_sample_rate = cls.negative_sample_rate - umap_params.transform_queue_size = cls.transform_queue_size - umap_params.verbosity = cls.verbose - umap_params.a = cls.a - umap_params.b = cls.b - if cls.init == "spectral": + umap_params.n_neighbors = self.n_neighbors + umap_params.n_components = self.n_components + umap_params.n_epochs = self.n_epochs if self.n_epochs else 0 + umap_params.learning_rate = self.learning_rate + umap_params.min_dist = self.min_dist + umap_params.spread = self.spread + umap_params.set_op_mix_ratio = self.set_op_mix_ratio + umap_params.local_connectivity = self.local_connectivity + umap_params.repulsion_strength = self.repulsion_strength + umap_params.negative_sample_rate = self.negative_sample_rate + umap_params.transform_queue_size = self.transform_queue_size + umap_params.verbosity = self.verbose + umap_params.a = self.a + umap_params.b = self.b + if self.init == "spectral": umap_params.init = 1 else: # self.init == "random" umap_params.init = 0 - umap_params.target_n_neighbors = cls.target_n_neighbors - if cls.target_metric == "euclidean": + umap_params.target_n_neighbors = self.target_n_neighbors + if self.target_metric == "euclidean": umap_params.target_metric = MetricType.EUCLIDEAN else: # self.target_metric == "categorical" umap_params.target_metric = MetricType.CATEGORICAL - if cls.build_algo == "brute_force_knn": + if self.build_algo == "brute_force_knn": umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN else: # self.init == "nn_descent" umap_params.build_algo = graph_build_algo.NN_DESCENT - if cls.build_kwds is None: + if self.build_kwds is None: umap_params.nn_descent_params.graph_degree = 64 umap_params.nn_descent_params.intermediate_graph_degree = 128 umap_params.nn_descent_params.max_iterations = 20 @@ -506,38 +491,38 @@ class UMAP(UniversalBase, umap_params.nn_descent_params.return_distances = True umap_params.nn_descent_params.n_clusters = 1 else: - umap_params.nn_descent_params.graph_degree = cls.build_kwds.get("nnd_graph_degree", 64) - umap_params.nn_descent_params.intermediate_graph_degree = cls.build_kwds.get("nnd_intermediate_graph_degree", 128) - umap_params.nn_descent_params.max_iterations = 
cls.build_kwds.get("nnd_max_iterations", 20) - umap_params.nn_descent_params.termination_threshold = cls.build_kwds.get("nnd_termination_threshold", 0.0001) - umap_params.nn_descent_params.return_distances = cls.build_kwds.get("nnd_return_distances", True) - if cls.build_kwds.get("nnd_n_clusters", 1) < 1: + umap_params.nn_descent_params.graph_degree = self.build_kwds.get("nnd_graph_degree", 64) + umap_params.nn_descent_params.intermediate_graph_degree = self.build_kwds.get("nnd_intermediate_graph_degree", 128) + umap_params.nn_descent_params.max_iterations = self.build_kwds.get("nnd_max_iterations", 20) + umap_params.nn_descent_params.termination_threshold = self.build_kwds.get("nnd_termination_threshold", 0.0001) + umap_params.nn_descent_params.return_distances = self.build_kwds.get("nnd_return_distances", True) + if self.build_kwds.get("nnd_n_clusters", 1) < 1: logger.info("Negative number of nnd_n_clusters not allowed. Changing nnd_n_clusters to 1") - umap_params.nn_descent_params.n_clusters = cls.build_kwds.get("nnd_n_clusters", 1) + umap_params.nn_descent_params.n_clusters = self.build_kwds.get("nnd_n_clusters", 1) - umap_params.target_weight = cls.target_weight - umap_params.random_state = cls.random_state - umap_params.deterministic = cls.deterministic + umap_params.target_weight = self.target_weight + umap_params.random_state = check_random_seed(self.random_state) + umap_params.deterministic = self.deterministic try: - umap_params.metric = metric_parsing[cls.metric.lower()] + umap_params.metric = metric_parsing[self.metric.lower()] if sparse: if umap_params.metric not in SPARSE_SUPPORTED_METRICS: - raise NotImplementedError(f"Metric '{cls.metric}' not supported for sparse inputs.") + raise NotImplementedError(f"Metric '{self.metric}' not supported for sparse inputs.") elif umap_params.metric not in DENSE_SUPPORTED_METRICS: - raise NotImplementedError(f"Metric '{cls.metric}' not supported for dense inputs.") + raise NotImplementedError(f"Metric '{self.metric}' not supported for dense inputs.") except KeyError: - raise ValueError(f"Invalid value for metric: {cls.metric}") + raise ValueError(f"Invalid value for metric: {self.metric}") - if cls.metric_kwds is None: + if self.metric_kwds is None: umap_params.p = 2.0 else: - umap_params.p = cls.metric_kwds.get('p') + umap_params.p = self.metric_kwds.get('p') cdef uintptr_t callback_ptr = 0 - if cls.callback: - callback_ptr = cls.callback.get_native_callback() + if self.callback: + callback_ptr = self.callback.get_native_callback() umap_params.callback = callback_ptr return umap_params @@ -672,7 +657,7 @@ class UMAP(UniversalBase, self.handle.getHandle() fss_graph = GraphHolder.new_graph(handle_.get_stream()) cdef UMAPParams* umap_params = \ - UMAP._build_umap_params(self, + self._build_umap_params( self.sparse_fit) if self.sparse_fit: fit_sparse(handle_[0], @@ -831,7 +816,7 @@ class UMAP(UniversalBase, IF GPUBUILD == 1: cdef UMAPParams* umap_params = \ - UMAP._build_umap_params(self, + self._build_umap_params( self.sparse_fit) cdef handle_t * handle_ = \ self.handle.getHandle() diff --git a/python/cuml/cuml/tests/test_common.py b/python/cuml/cuml/tests/test_common.py new file mode 100644 index 0000000000..0dc6ccd575 --- /dev/null +++ b/python/cuml/cuml/tests/test_common.py @@ -0,0 +1,42 @@ +# +# Copyright (c) 2024-2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest + +import numpy as np +import cuml +from cuml.datasets import make_blobs + + +@pytest.mark.parametrize( + "Estimator", + [ + cuml.KMeans, + cuml.RandomForestRegressor, + cuml.RandomForestClassifier, + cuml.TSNE, + cuml.UMAP, + ], +) +def test_random_state_argument(Estimator): + X, y = make_blobs(random_state=0) + # Check that both integer and np.random.RandomState are accepted + for seed in (42, np.random.RandomState(42)): + est = Estimator(random_state=seed) + + if est.__class__.__name__ != "TSNE": + est.fit(X, y) + else: + est.fit(X) diff --git a/python/cuml/cuml/tests/test_device_selection.py b/python/cuml/cuml/tests/test_device_selection.py index 31c0f9aed6..1ff67d2891 100644 --- a/python/cuml/cuml/tests/test_device_selection.py +++ b/python/cuml/cuml/tests/test_device_selection.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# Copyright (c) 2022-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -977,11 +977,11 @@ def test_hdbscan_methods(train_device, infer_device): @pytest.mark.parametrize("infer_device", ["cpu", "gpu"]) def test_kmeans_methods(train_device, infer_device): n_clusters = 20 - ref_model = skKMeans(n_clusters=n_clusters) + ref_model = skKMeans(n_clusters=n_clusters, random_state=42) ref_model.fit(X_train_blob) ref_output = ref_model.predict(X_test_blob) - model = KMeans(n_clusters=n_clusters) + model = KMeans(n_clusters=n_clusters, random_state=42) with using_device_type(train_device): model.fit(X_train_blob) with using_device_type(infer_device):
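
The new `check_random_seed` helper added in `python/cuml/cuml/internals/utils.py` is the common entry point for every estimator touched above: an integer passes through unchanged, a `np.random.RandomState` yields a reproducible `uint32` draw, and `None` produces a fresh nondeterministic seed. A minimal usage sketch of those semantics, assuming a cuML build where `cuml.internals.utils` is importable:

```python
import numpy as np

from cuml.internals.utils import check_random_seed

# Integers are returned unchanged.
assert check_random_seed(42) == 42

# Two RandomState instances with the same state derive the same uint32 seed.
assert check_random_seed(np.random.RandomState(0)) == check_random_seed(
    np.random.RandomState(0)
)

# None draws a seed from an unseeded RandomState, so the result is nondeterministic.
print(check_random_seed(None))
```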
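At the estimator level, `random_state` may now be either an int or a `np.random.RandomState` instance, and the integer seed is only derived at fit time, as exercised by the new `test_random_state_argument` test. A short end-to-end sketch along the same lines, again assuming a working GPU cuML installation:

```python
import numpy as np

from cuml import KMeans
from cuml.datasets import make_blobs

X, _ = make_blobs(random_state=0)

# Both seed forms are accepted; the concrete seed is resolved via
# check_random_seed() inside fit(), not in __init__.
for seed in (42, np.random.RandomState(42)):
    KMeans(random_state=seed).fit(X)
```

UMAP follows the same pattern: instead of eagerly converting `random_state` to a `np.uint32` in `__init__`, the raw value is kept and only resolved through `check_random_seed` when `_build_umap_params` builds the C++ `UMAPParams`.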