Skip to content

Commit

Permalink
Merge branch 'branch-25.02' into fix/decrease-umap-logging-verbosity
Browse files Browse the repository at this point in the history
  • Loading branch information
csadorf authored Feb 7, 2025
2 parents 3252fc7 + f60b5f0 commit 823e581
Show file tree
Hide file tree
Showing 21 changed files with 907 additions and 109 deletions.
13 changes: 13 additions & 0 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ jobs:
# Please keep pr-builder as the top job here
pr-builder:
needs:
- check-nightly-ci
- changed-files
- checks
- clang-tidy
Expand Down Expand Up @@ -43,6 +44,18 @@ jobs:
- name: Telemetry setup
if: ${{ vars.TELEMETRY_ENABLED == 'true' }}
uses: rapidsai/shared-actions/telemetry-dispatch-stash-base-env-vars@main
check-nightly-ci:
# Switch to ubuntu-latest once it defaults to a version of Ubuntu that
# provides at least Python 3.11 (see
# https://docs.python.org/3/library/datetime.html#datetime.date.fromisoformat)
runs-on: ubuntu-24.04
env:
RAPIDS_GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
steps:
- name: Check if nightly CI is passing
uses: rapidsai/shared-actions/check_nightly_success/dispatch@main
with:
repo: cuml
changed-files:
secrets: inherit
needs: telemetry-setup
Expand Down
2 changes: 1 addition & 1 deletion ci/build_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ cd "${package_dir}"
sccache --zero-stats

rapids-logger "Building '${package_name}' wheel"
python -m pip wheel \
rapids-pip-retry wheel \
-w dist \
-v \
--no-deps \
Expand Down
2 changes: 1 addition & 1 deletion ci/build_wheel_libcuml.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ rapids-dependency-file-generator \
| tee /tmp/requirements-build.txt

rapids-logger "Installing build requirements"
python -m pip install \
rapids-pip-retry install \
-v \
--prefer-binary \
-r /tmp/requirements-build.txt
Expand Down
2 changes: 1 addition & 1 deletion ci/test_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ RAPIDS_PY_WHEEL_NAME="cuml_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from
RAPIDS_PY_WHEEL_NAME="libcuml_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 cpp ./dist

# echo to expand wildcard before adding `[extra]` requires for pip
python -m pip install \
rapids-pip-retry install \
./dist/libcuml*.whl \
"$(echo ./dist/cuml*.whl)[test]"

Expand Down
45 changes: 28 additions & 17 deletions cpp/src/umap/knn_graph/algo.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -118,31 +118,42 @@ inline void launcher(const raft::handle_t& handle,
// TODO: use nndescent from cuvs
RAFT_EXPECTS(static_cast<size_t>(n_neighbors) <= params->nn_descent_params.graph_degree,
"n_neighbors should be smaller than the graph degree computed by nn descent");
RAFT_EXPECTS(params->nn_descent_params.return_distances,
"return_distances for nn descent should be set to true to be used for UMAP");

auto graph = get_graph_nnd(handle, inputsA, params);

auto indices_d = raft::make_device_matrix<int64_t, int64_t>(
handle, inputsA.n, params->nn_descent_params.graph_degree);

raft::copy(indices_d.data_handle(),
graph.graph().data_handle(),
inputsA.n * params->nn_descent_params.graph_degree,
stream);

// `graph.graph()` is a host array (n x graph_degree).
// Slice and copy to a temporary host array (n x n_neighbors), then copy
// that to the output device array `out.knn_indices` (n x n_neighbors).
// TODO: force graph_degree = n_neighbors so the temporary host array and
// slice isn't necessary.
auto temp_indices_h = raft::make_host_matrix<int64_t, int64_t>(inputsA.n, n_neighbors);
size_t graph_degree = params->nn_descent_params.graph_degree;
#pragma omp parallel for
for (size_t i = 0; i < static_cast<size_t>(inputsA.n); i++) {
for (int j = 0; j < n_neighbors; j++) {
auto target = temp_indices_h.data_handle();
auto source = graph.graph().data_handle();
target[i * n_neighbors + j] = source[i * graph_degree + j];
}
}
raft::copy(handle,
raft::make_device_matrix_view(out.knn_indices, inputsA.n, n_neighbors),
temp_indices_h.view());

// `graph.distances()` is a device array (n x graph_degree).
// Slice and copy to the output device array `out.knn_dists` (n x n_neighbors).
// TODO: force graph_degree = n_neighbors so this slice isn't necessary.
raft::matrix::slice_coordinates coords{static_cast<int64_t>(0),
static_cast<int64_t>(0),
static_cast<int64_t>(inputsA.n),
static_cast<int64_t>(n_neighbors)};

RAFT_EXPECTS(graph.distances().has_value(),
"return_distances for nn descent should be set to true to be used for UMAP");
auto out_knn_dists_view = raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors);
raft::matrix::slice<float, int64_t, raft::row_major>(
handle, raft::make_const_mdspan(graph.distances().value()), out_knn_dists_view, coords);
auto out_knn_indices_view =
raft::make_device_matrix_view(out.knn_indices, inputsA.n, n_neighbors);
raft::matrix::slice<int64_t, int64_t, raft::row_major>(
handle, raft::make_const_mdspan(indices_d.view()), out_knn_indices_view, coords);
handle,
raft::make_const_mdspan(graph.distances().value()),
raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors),
coords);
}
}

Expand Down
7 changes: 6 additions & 1 deletion python/cuml/cuml/cluster/kmeans.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ np = cpu_only_import('numpy')
from cuml.internals.safe_imports import gpu_only_import
rmm = gpu_only_import('rmm')
from cuml.internals.safe_imports import safe_import_from, return_false
from cuml.internals.utils import check_random_seed
import typing

IF GPUBUILD == 1:
Expand Down Expand Up @@ -209,8 +210,11 @@ class KMeans(UniversalBase,
params.init = self._params_init
params.max_iter = <int>self.max_iter
params.tol = <double>self.tol
# After transferring from one device to another `_seed` might not be set
# so we need to pass a dummy value here. Its value does not matter as the
# seed is only used during fitting
params.rng_state.seed = <int>getattr(self, "_seed", 0)
params.verbosity = <raft_level_enum>(<int>self.verbose)
params.rng_state.seed = self.random_state
params.metric = DistanceType.L2Expanded # distance metric as squared L2: @todo - support other metrics # noqa: E501
params.batch_samples = <int>self.max_samples_per_batch
params.oversampling_factor = <double>self.oversampling_factor
Expand Down Expand Up @@ -307,6 +311,7 @@ class KMeans(UniversalBase,
else None),
check_dtype=check_dtype)

self._seed = check_random_seed(self.random_state)
self.feature_names_in_ = _X_m.index

IF GPUBUILD == 1:
Expand Down
5 changes: 4 additions & 1 deletion python/cuml/cuml/cluster/kmeans_mg.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
# Copyright (c) 2019-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,6 +32,7 @@ from cuml.common import input_to_cuml_array

from cuml.cluster import KMeans
from cuml.cluster.kmeans_utils cimport params as KMeansParams
from cuml.internals.utils import check_random_seed


cdef extern from "cuml/cluster/kmeans_mg.hpp" \
Expand Down Expand Up @@ -129,6 +130,8 @@ class KMeansMG(KMeans):

cdef uintptr_t sample_weight_ptr = sample_weight_m.ptr

self._seed = check_random_seed(self.random_state)

if (self.init in ['scalable-k-means++', 'k-means||', 'random']):
self.cluster_centers_ = CumlArray.zeros(shape=(self.n_clusters,
self.n_cols),
Expand Down
10 changes: 3 additions & 7 deletions python/cuml/cuml/decomposition/pca.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
# Copyright (c) 2019-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -209,9 +209,6 @@ class PCA(UniversalBase,
``n_components = min(n_samples, n_features)``
random_state : int / None (default = None)
If you want results to be the same when you restart Python, select a
state.
svd_solver : 'full' or 'jacobi' or 'auto' (default = 'full')
Full uses a eigendecomposition of the covariance matrix then discards
components.
Expand Down Expand Up @@ -292,7 +289,7 @@ class PCA(UniversalBase,

@device_interop_preparation
def __init__(self, *, copy=True, handle=None, iterated_power=15,
n_components=None, random_state=None, svd_solver='auto',
n_components=None, svd_solver='auto',
tol=1e-7, verbose=False, whiten=False,
output_type=None):
# parameters
Expand All @@ -302,7 +299,6 @@ class PCA(UniversalBase,
self.copy = copy
self.iterated_power = iterated_power
self.n_components = n_components
self.random_state = random_state
self.svd_solver = svd_solver
self.tol = tol
self.whiten = whiten
Expand Down Expand Up @@ -739,7 +735,7 @@ class PCA(UniversalBase,
def _get_param_names(cls):
return super()._get_param_names() + \
["copy", "iterated_power", "n_components", "svd_solver", "tol",
"whiten", "random_state"]
"whiten"]

def _check_is_fitted(self, attr):
if not hasattr(self, attr) or (getattr(self, attr) is None):
Expand Down
35 changes: 30 additions & 5 deletions python/cuml/cuml/ensemble/randomforest_common.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,7 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import threading
import treelite.sklearn
from cuml.internals.safe_imports import gpu_only_import
from cuml.internals.api_decorators import device_interop_preparation
from cuml.internals.global_settings import GlobalSettings

cp = gpu_only_import('cupy')
import math
import warnings
Expand All @@ -24,7 +29,7 @@ np = cpu_only_import('numpy')
from cuml import ForestInference
from cuml.fil.fil import TreeliteModel
from pylibraft.common.handle import Handle
from cuml.internals.base import Base
from cuml.internals.base import UniversalBase
from cuml.internals.array import CumlArray
from cuml.common.exceptions import NotFittedError
import cuml.internals
Expand All @@ -39,7 +44,7 @@ from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.prims.label.classlabels import make_monotonic, check_labels


class BaseRandomForestModel(Base):
class BaseRandomForestModel(UniversalBase):
_param_names = ['n_estimators', 'max_depth', 'handle',
'max_features', 'n_bins',
'split_criterion', 'min_samples_leaf',
Expand Down Expand Up @@ -67,6 +72,7 @@ class BaseRandomForestModel(Base):

classes_ = CumlArrayDescriptor()

@device_interop_preparation
def __init__(self, *, split_criterion, n_streams=4, n_estimators=100,
max_depth=16, handle=None, max_features='sqrt', n_bins=128,
bootstrap=True,
Expand All @@ -88,7 +94,7 @@ class BaseRandomForestModel(Base):
"class_weight": class_weight}

for key, vals in sklearn_params.items():
if vals:
if vals and not GlobalSettings().accelerator_active:
raise TypeError(
" The Scikit-learn variable ", key,
" is not supported in cuML,"
Expand All @@ -97,7 +103,7 @@ class BaseRandomForestModel(Base):
"api.html#random-forest) for more information")

for key in kwargs.keys():
if key not in self._param_names:
if key not in self._param_names and not GlobalSettings().accelerator_active:
raise TypeError(
" The variable ", key,
" is not supported in cuML,"
Expand Down Expand Up @@ -154,6 +160,7 @@ class BaseRandomForestModel(Base):
self.model_pbuf_bytes = bytearray()
self.treelite_handle = None
self.treelite_serialized_model = None
self._cpu_model_class_lock = threading.RLock()

def _get_max_feat_val(self) -> float:
if isinstance(self.max_features, int):
Expand Down Expand Up @@ -268,6 +275,24 @@ class BaseRandomForestModel(Base):
self.treelite_handle = <uintptr_t> tl_handle
return self.treelite_handle

def cpu_to_gpu(self):
tl_model = treelite.sklearn.import_model(self._cpu_model)
self._temp = TreeliteModel.from_treelite_bytes(tl_model.serialize_bytes())
self.treelite_serialized_model = treelite_serialize(self._temp.handle)
self._obtain_treelite_handle()
self.dtype = np.float64
self.update_labels = False
super().cpu_to_gpu()

def gpu_to_cpu(self):
self._obtain_treelite_handle()
tl_model = TreeliteModel.from_treelite_model_handle(
self.treelite_handle,
take_handle_ownership=False)
tl_bytes = tl_model.to_treelite_bytes()
tl_model2 = treelite.Model.deserialize_bytes(tl_bytes)
self._cpu_model = treelite.sklearn.export_model(tl_model2)

@cuml.internals.api_base_return_generic(set_output_type=True,
set_n_features_in=True,
get_output_type=False)
Expand Down
Loading

0 comments on commit 823e581

Please sign in to comment.