Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'branch-25.04' into nvks-runners
Browse files Browse the repository at this point in the history
jameslamb authored Jan 31, 2025
2 parents 71d7e07 + 1e0e147 commit 6bfb1d0
Showing 2 changed files with 4 additions and 11 deletions.
8 changes: 4 additions & 4 deletions python/cuml/cuml/linear_model/logistic_regression_mg.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Copyright (c) 2023-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -263,7 +263,7 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):

cdef handle_t* handle_ = <handle_t*><size_t>self.handle.getHandle()
cdef float objective32
cdef float objective64
cdef double objective64
cdef int num_iters

cdef vector[float] c_classes_
@@ -387,7 +387,7 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
qnpams,
<bool> self.standardization,
<int> self._num_classes,
<double*> &objective32,
<double*> &objective64,
<int*> &num_iters)
else:
assert self.index_dtype == np.int64, f"unsupported index dtype: {self.index_dtype}"
@@ -403,7 +403,7 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
qnpams,
<bool> self.standardization,
<int> self._num_classes,
<double*> &objective32,
<double*> &objective64,
<int*> &num_iters)

self.solver_model.objective = objective64
7 changes: 0 additions & 7 deletions python/cuml/cuml/tests/dask/test_dask_logistic_regression.py
Original file line number Diff line number Diff line change
@@ -34,13 +34,6 @@
dask_cudf = gpu_only_import("dask_cudf")
cudf = gpu_only_import("cudf")

pytestmark = [
pytest.mark.mg,
pytest.mark.skip(
reason="pytest hang https://github.com/rapidsai/cuml/issues/6247"
),
]


def _prep_training_data(c, X_train, y_train, partitions_per_worker):
workers = c.has_what().keys()

0 comments on commit 6bfb1d0

Please sign in to comment.