Skip to content

Commit

Permalink
Fix LogisticRegression.decision_function output shape
Browse files Browse the repository at this point in the history
Previously this would return `(n_classes, n_rows)` for multiclass,
whereas sklearn returns `(n_rows, n_classes)`.
  • Loading branch information
jcrist committed Jan 17, 2025
1 parent d95cae5 commit ede0278
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 8 deletions.
4 changes: 1 addition & 3 deletions python/cuml/cuml/linear_model/logistic_regression.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,7 @@ class LogisticRegression(UniversalBase,
log_proba=False) -> CumlArray:
_num_classes = self.classes_.shape[0]

scores = cp.asarray(
self.decision_function(X, convert_dtype=convert_dtype), order="F"
).T
scores = self.decision_function(X, convert_dtype=convert_dtype).to_output("cupy")
if _num_classes == 2:
proba = cp.zeros((scores.shape[0], 2))
proba[:, 1] = 1 / (1 + cp.exp(-scores.ravel()))
Expand Down
8 changes: 5 additions & 3 deletions python/cuml/cuml/solvers/qn.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
# Copyright (c) 2019-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -662,7 +662,9 @@ class QN(Base,
Returns
----------
y: array-like (device)
Dense matrix (floats or doubles) of shape (n_samples, n_classes)
Dense matrix (floats or doubles) of shape (n_samples,), or
(n_samples, n_classes) if more than 2 classes.

"""
coefs = self.coef_
dtype = coefs.dtype
Expand Down Expand Up @@ -776,7 +778,7 @@ class QN(Base,

del X_m

return scores
return scores.to_output("array").T

@generate_docstring(
X='dense_sparse',
Expand Down
2 changes: 0 additions & 2 deletions python/cuml/cuml/tests/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,8 +717,6 @@ def test_logistic_regression_decision_function(
sklog.classes_ = np.arange(num_classes)

cu_dec_func = culog.decision_function(X_test)
if cu_dec_func.shape[0] > 2: # num_classes
cu_dec_func = cu_dec_func.T
sk_dec_func = sklog.decision_function(X_test)

assert array_equal(cu_dec_func, sk_dec_func)
Expand Down

0 comments on commit ede0278

Please sign in to comment.