Skip to content

Commit

Permalink
Merge pull request #5674 from rapidsai/branch-23.12
Browse files Browse the repository at this point in the history
Forward-merge branch-23.12 to branch-24.02
  • Loading branch information
GPUtester authored Nov 29, 2023
2 parents 72bde03 + 97b6fa3 commit 529866b
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 18 deletions.
2 changes: 1 addition & 1 deletion cpp/src/glm/qn/mg/qn_mg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ inline void qn_fit_x_mg(const raft::handle_t& handle,

switch (pams.loss) {
case QN_LOSS_LOGISTIC: {
ASSERT(C == 2, "qn_mg.cuh: logistic loss invalid C");
ASSERT(C > 0, "qn_mg.cuh: logistic loss invalid C");
ML::GLM::detail::LogisticLoss<T> loss(handle, D, pams.fit_intercept);
ML::GLM::opg::qn_fit_mg<T, decltype(loss)>(
handle, pams, loss, X, y, Z, w0_data, f, num_iters, n_samples, rank, n_ranks);
Expand Down
9 changes: 8 additions & 1 deletion python/cuml/dask/linear_model/logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,13 @@ def _func_fit(f, data, n_rows, n_cols, partsToSizes, rank):
for p in partsToSizes:
aggregated_partsToSizes[p[0]][1] += p[1]

return f.fit(
ret_status = f.fit(
[(inp_X, inp_y)], n_rows, n_cols, aggregated_partsToSizes, rank
)

if len(f.classes_) == 1:
raise ValueError(
f"This solver needs samples of at least 2 classes in the data, but the data contains only one class: {f.classes_[0]}"
)

return ret_status
1 change: 0 additions & 1 deletion python/cuml/linear_model/base_mg.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ class MGFitMixin(object):
check_dtype = self.dtype

if sparse_input:

X_m = SparseCumlArray(input_data[i][0], convert_index=np.int32)
_, self.n_cols = X_m.shape
else:
Expand Down
6 changes: 3 additions & 3 deletions python/cuml/linear_model/logistic_regression_mg.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
"with softmax (multinomial).")

if solves_classification and not solves_multiclass:
self._num_classes_dim = self._num_classes - 1
self._num_classes_dim = 1
else:
self._num_classes_dim = self._num_classes

Expand All @@ -185,7 +185,6 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):

def fit(self, input_data, n_rows, n_cols, parts_rank_size, rank, convert_dtype=False):

self.rank = rank
assert len(input_data) == 1, f"Currently support only one (X, y) pair in the list. Received {len(input_data)} pairs."
self.is_col_major = False
order = 'F' if self.is_col_major else 'C'
Expand All @@ -207,11 +206,12 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
self._num_classes = len(self.classes_)
self.loss = "sigmoid" if self._num_classes <= 2 else "softmax"
self.prepare_for_fit(self._num_classes)

cdef uintptr_t mat_coef_ptr = self.coef_.ptr

cdef qn_params qnpams = self.solver_model.qnparams.params

sparse_input = True if isinstance(X, list) else False
sparse_input = isinstance(X, list)

if self.dtype == np.float32:
if sparse_input is False:
Expand Down
56 changes: 44 additions & 12 deletions python/cuml/tests/dask/test_dask_logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from cuml.internals.safe_imports import gpu_only_import
import pytest
from cuml.dask.common import utils as dask_utils
from functools import partial
from sklearn.metrics import accuracy_score, mean_squared_error
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression as skLR
Expand Down Expand Up @@ -339,12 +340,12 @@ def imp():
datatype, nrows, ncols, n_info, n_classes=n_classes
)

if convert_to_sparse is False:
# X_dask and y_dask are dask cudf
X_dask, y_dask = _prep_training_data(client, X, y, n_parts)
else:
if convert_to_sparse:
# X_dask and y_dask are dask array
X_dask, y_dask = _prep_training_data_sparse(client, X, y, n_parts)
else:
# X_dask and y_dask are dask cudf
X_dask, y_dask = _prep_training_data(client, X, y, n_parts)

lr = cumlLBFGS_dask(
solver="qn",
Expand Down Expand Up @@ -557,23 +558,21 @@ def test_elasticnet(
("elasticnet", 2.0, 0.2),
],
)
@pytest.mark.parametrize("datatype", [np.float32])
@pytest.mark.parametrize("delayed", [True])
@pytest.mark.parametrize("datatype", [np.float32, np.float64])
@pytest.mark.parametrize("n_classes", [2, 8])
def test_sparse_from_dense(
fit_intercept, regularization, datatype, delayed, n_classes, client
fit_intercept, regularization, datatype, n_classes, client
):
penalty = regularization[0]
C = regularization[1]
l1_ratio = regularization[2]
penalty, C, l1_ratio = regularization

test_lbfgs(
run_test = partial(
test_lbfgs,
nrows=1e5,
ncols=20,
n_parts=2,
fit_intercept=fit_intercept,
datatype=datatype,
delayed=delayed,
delayed=True,
client=client,
penalty=penalty,
n_classes=n_classes,
Expand All @@ -582,6 +581,15 @@ def test_sparse_from_dense(
convert_to_sparse=True,
)

if datatype == np.float32:
run_test()
else:
with pytest.raises(
RuntimeError,
match="dtypes other than float32 are currently not supported",
):
run_test()


@pytest.mark.parametrize("dtype", [np.float32])
def test_sparse_nlp20news(dtype, nlp_20news, client):
Expand Down Expand Up @@ -621,3 +629,27 @@ def test_sparse_nlp20news(dtype, nlp_20news, client):
cpu_preds = cpu.predict(X_test)
cpu_score = accuracy_score(y_test, cpu_preds.tolist())
assert cuml_score >= cpu_score or np.abs(cuml_score - cpu_score) < 1e-3


@pytest.mark.parametrize("fit_intercept", [False, True])
def test_exception_one_label(fit_intercept, client):
n_parts = 2
datatype = "float32"

X = np.array([(1, 2), (1, 3), (2, 1), (3, 1)], datatype)
y = np.array([1.0, 1.0, 1.0, 1.0], datatype)
X_df, y_df = _prep_training_data(client, X, y, n_parts)

err_msg = "This solver needs samples of at least 2 classes in the data, but the data contains only one class: 1.0"

from cuml.dask.linear_model import LogisticRegression as cumlLBFGS_dask

mg = cumlLBFGS_dask(fit_intercept=fit_intercept, verbose=6)
with pytest.raises(RuntimeError, match=err_msg):
mg.fit(X_df, y_df)

from sklearn.linear_model import LogisticRegression

lr = LogisticRegression(fit_intercept=fit_intercept)
with pytest.raises(ValueError, match=err_msg):
lr.fit(X, y)

0 comments on commit 529866b

Please sign in to comment.