Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add serialization for incremental linear models #2211

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions deselected_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -276,14 +276,6 @@ deselected_tests:
- tests/test_common.py::test_estimators[LogisticRegression()-check_sample_weights_invariance(kind=zeros)] >=1.4
- tests/test_multioutput.py::test_classifier_chain_fit_and_predict_with_sparse_data >=1.4

# Deselected tests for incremental algorithms
# Need to rework getting policy to correctly obtain it for method without data (finalize_fit)
# and avoid keeping it in class attribute, also need to investigate how to implement
# partial result serialization
- tests/test_common.py::test_estimators[IncrementalLinearRegression()-check_estimators_pickle]
- tests/test_common.py::test_estimators[IncrementalLinearRegression()-check_estimators_pickle(readonly_memmap=True)]
- tests/test_common.py::test_estimators[IncrementalRidge()-check_estimators_pickle]
- tests/test_common.py::test_estimators[IncrementalRidge()-check_estimators_pickle(readonly_memmap=True)]
# There are not enough data to run onedal backend
- tests/test_common.py::test_estimators[IncrementalLinearRegression()-check_fit2d_1sample]
- tests/test_common.py::test_estimators[IncrementalRidge()-check_fit2d_1sample]
Expand Down
16 changes: 5 additions & 11 deletions onedal/datatypes/data_conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,7 @@ PyObject *convert_to_pyobject(const dal::table &input) {
}
if (input.get_kind() == dal::homogen_table::kind()) {
const auto &homogen_input = static_cast<const dal::homogen_table &>(input);
if (homogen_input.get_data_layout() == dal::data_layout::row_major) {
const dal::data_type dtype = homogen_input.get_metadata().get_data_type(0);
const dal::data_type dtype = homogen_input.get_metadata().get_data_type(0);

#define MAKE_NYMPY_FROM_HOMOGEN(NpType) \
{ \
Expand All @@ -384,16 +383,11 @@ PyObject *convert_to_pyobject(const dal::table &input) {
homogen_input.get_row_count(), \
homogen_input.get_column_count()); \
}
SET_CTYPE_NPY_FROM_DAL_TYPE(
dtype,
MAKE_NYMPY_FROM_HOMOGEN,
throw std::invalid_argument("Not avalible to convert a numpy"));
SET_CTYPE_NPY_FROM_DAL_TYPE(
dtype,
MAKE_NYMPY_FROM_HOMOGEN,
throw std::invalid_argument("Not avalible to convert a numpy"));
#undef MAKE_NYMPY_FROM_HOMOGEN
}
else {
throw std::invalid_argument(
"Output oneDAL table doesn't have row major format for homogen table");
}
}
else if (input.get_kind() == csr_table_t::kind()) {
const auto &csr_input = static_cast<const csr_table_t &>(input);
Expand Down
102 changes: 69 additions & 33 deletions onedal/linear_model/incremental_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,22 @@ def __init__(self, fit_intercept=True, copy_X=False, algorithm="norm_eq"):
self._reset()

def _reset(self):
self._need_to_finalize = False
self._partial_result = self._get_backend(
"linear_model", "regression", "partial_train_result"
)

def __getstate__(self):
# Since finalize_fit can't be dispatched without directly provided queue
# and the dispatching policy can't be serialized, the computation is finalized
# here and the policy is not saved in serialized data.

self.finalize_fit()
data = self.__dict__.copy()
data.pop("_queue", None)

return data

def partial_fit(self, X, y, queue=None):
"""
Computes partial data for linear regression
Expand Down Expand Up @@ -106,6 +118,9 @@ def partial_fit(self, X, y, queue=None):
policy, self._params, self._partial_result, X_table, y_table
)

self._need_to_finalize = True
return self

def finalize_fit(self, queue=None):
"""
Finalizes linear regression computation and obtains coefficients
Expand All @@ -122,27 +137,30 @@ def finalize_fit(self, queue=None):
Returns the instance itself.
"""

if queue is not None:
policy = self._get_policy(queue)
else:
policy = self._get_policy(self._queue)

module = self._get_backend("linear_model", "regression")
hparams = get_hyperparameters("linear_regression", "train")
if hparams is not None and not hparams.is_default:
result = module.finalize_train(
policy, self._params, hparams.backend, self._partial_result
if self._need_to_finalize:
if queue is not None:
policy = self._get_policy(queue)
else:
policy = self._get_policy(self._queue)

module = self._get_backend("linear_model", "regression")
hparams = get_hyperparameters("linear_regression", "train")
if hparams is not None and not hparams.is_default:
result = module.finalize_train(
policy, self._params, hparams.backend, self._partial_result
)
else:
result = module.finalize_train(policy, self._params, self._partial_result)

self._onedal_model = result.model

packed_coefficients = from_table(result.model.packed_coefficients)
self.coef_, self.intercept_ = (
packed_coefficients[:, 1:].squeeze(),
packed_coefficients[:, 0].squeeze(),
)
else:
result = module.finalize_train(policy, self._params, self._partial_result)

self._onedal_model = result.model

packed_coefficients = from_table(result.model.packed_coefficients)
self.coef_, self.intercept_ = (
packed_coefficients[:, 1:].squeeze(),
packed_coefficients[:, 0].squeeze(),
)
self._need_to_finalize = False

return self

Expand Down Expand Up @@ -171,15 +189,26 @@ class IncrementalRidge(BaseLinearRegression):
"""

def __init__(self, alpha=1.0, fit_intercept=True, copy_X=False, algorithm="norm_eq"):
module = self._get_backend("linear_model", "regression")
super().__init__(
fit_intercept=fit_intercept, alpha=alpha, copy_X=copy_X, algorithm=algorithm
)
self._partial_result = module.partial_train_result()
self._reset()

def _reset(self):
module = self._get_backend("linear_model", "regression")
self._partial_result = module.partial_train_result()
self._need_to_finalize = False

def __getstate__(self):
# Since finalize_fit can't be dispatched without directly provided queue
# and the dispatching policy can't be serialized, the computation is finalized
# here and the policy is not saved in serialized data.

self.finalize_fit()
data = self.__dict__.copy()
data.pop("_queue", None)

return data

def partial_fit(self, X, y, queue=None):
"""
Expand Down Expand Up @@ -225,6 +254,9 @@ def partial_fit(self, X, y, queue=None):
policy, self._params, self._partial_result, X_table, y_table
)

self._need_to_finalize = True
return self

def finalize_fit(self, queue=None):
"""
Finalizes ridge regression computation and obtains coefficients
Expand All @@ -240,19 +272,23 @@ def finalize_fit(self, queue=None):
self : object
Returns the instance itself.
"""
module = self._get_backend("linear_model", "regression")
if queue is not None:
policy = self._get_policy(queue)
else:
policy = self._get_policy(self._queue)
result = module.finalize_train(policy, self._params, self._partial_result)

self._onedal_model = result.model
if self._need_to_finalize:
module = self._get_backend("linear_model", "regression")
if queue is not None:
policy = self._get_policy(queue)
else:
policy = self._get_policy(self._queue)
result = module.finalize_train(policy, self._params, self._partial_result)

packed_coefficients = from_table(result.model.packed_coefficients)
self.coef_, self.intercept_ = (
packed_coefficients[:, 1:].squeeze(),
packed_coefficients[:, 0].squeeze(),
)
self._onedal_model = result.model

packed_coefficients = from_table(result.model.packed_coefficients)
self.coef_, self.intercept_ = (
packed_coefficients[:, 1:].squeeze(),
packed_coefficients[:, 0].squeeze(),
)

self._need_to_finalize = False

return self
21 changes: 20 additions & 1 deletion onedal/linear_model/linear_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
#include "onedal/common.hpp"
#include "onedal/version.hpp"

#define NO_IMPORT_ARRAY // import_array called in table.cpp
#include "onedal/datatypes/data_conversion.hpp"

#include <regex>

namespace py = pybind11;
Expand Down Expand Up @@ -241,7 +244,23 @@ void init_partial_train_result(py::module_& m) {
py::class_<result_t>(m, "partial_train_result")
.def(py::init())
.DEF_ONEDAL_PY_PROPERTY(partial_xtx, result_t)
.DEF_ONEDAL_PY_PROPERTY(partial_xty, result_t);
.DEF_ONEDAL_PY_PROPERTY(partial_xty, result_t)
.def(py::pickle(
[](const result_t& res) {
return py::make_tuple(
py::cast<py::object>(convert_to_pyobject(res.get_partial_xtx())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_xty()))
);
},
[](py::tuple t) {
if (t.size() != 2)
throw std::runtime_error("Invalid state!");
result_t res;
if (py::cast<int>(t[0].attr("size")) != 0) res.set_partial_xtx(convert_to_table(t[0].ptr()));
if (py::cast<int>(t[1].attr("size")) != 0) res.set_partial_xty(convert_to_table(t[1].ptr()));
return res;
}
));
}

template <typename Task>
Expand Down
92 changes: 68 additions & 24 deletions onedal/linear_model/tests/test_incremental_linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@

import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal
from numpy.testing import assert_allclose
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

from onedal.datatypes import from_table
from onedal.linear_model import IncrementalLinearRegression
from onedal.tests.utils._device_selection import get_queues

Expand All @@ -43,29 +44,6 @@
assert mean_squared_error(y_test, y_pred) < 2396


@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.skip(reason="pickling not implemented for oneDAL entities")
def test_pickle(queue, dtype):
# TODO Implement pickling for oneDAL entities
X, y = load_diabetes(return_X_y=True)
X, y = X.astype(dtype), y.astype(dtype)
model = IncrementalLinearRegression(fit_intercept=True)
model.partial_fit(X, y, queue=queue)
model.finalize_fit()
expected = model.predict(X, queue=queue)

import pickle

dump = pickle.dumps(model)
model2 = pickle.loads(dump)

assert isinstance(model2, model.__class__)
result = model2.predict(X, queue=queue)

assert_array_equal(expected, result)


@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("num_blocks", [1, 2, 10])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
Expand Down Expand Up @@ -166,3 +144,69 @@

tol = 1e-5 if res.dtype == np.float32 else 1e-7
assert_allclose(gtr, res, rtol=tol)


@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_incremental_estimator_pickle(queue, dtype):
import pickle

from onedal.linear_model import IncrementalLinearRegression

inclr = IncrementalLinearRegression()

# Check that estimator can be serialized without any data.
dump = pickle.dumps(inclr)
inclr_loaded = pickle.loads(dump)

Check warning on line 160 in onedal/linear_model/tests/test_incremental_linear_regression.py

View check run for this annotation

codefactor.io / CodeFactor

onedal/linear_model/tests/test_incremental_linear_regression.py#L160

Pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue. (B301)
seed = 77
gen = np.random.default_rng(seed)
X = gen.uniform(low=-0.3, high=+0.7, size=(10, 10))
X = X.astype(dtype)
coef = gen.random(size=(1, 10), dtype=dtype).T
y = X @ coef
X_split = np.array_split(X, 2)
y_split = np.array_split(y, 2)
inclr.partial_fit(X_split[0], y_split[0], queue=queue)
inclr_loaded.partial_fit(X_split[0], y_split[0], queue=queue)

# inclr.finalize_fit()

assert inclr._need_to_finalize == True
assert inclr_loaded._need_to_finalize == True

# Check that estimator can be serialized after partial_fit call.
dump = pickle.dumps(inclr)
inclr_loaded = pickle.loads(dump)

Check warning on line 179 in onedal/linear_model/tests/test_incremental_linear_regression.py

View check run for this annotation

codefactor.io / CodeFactor

onedal/linear_model/tests/test_incremental_linear_regression.py#L179

Pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue. (B301)

partial_xtx = from_table(inclr._partial_result.partial_xtx)
partial_xtx_loaded = from_table(inclr_loaded._partial_result.partial_xtx)
assert_allclose(partial_xtx, partial_xtx_loaded)

partial_xty = from_table(inclr._partial_result.partial_xty)
partial_xty_loaded = from_table(inclr_loaded._partial_result.partial_xty)
assert_allclose(partial_xty, partial_xty_loaded)

assert inclr._need_to_finalize == False
# Finalize is called during serialization to make sure partial results are finalized correctly.
assert inclr_loaded._need_to_finalize == False

inclr.partial_fit(X_split[1], y_split[1], queue=queue)
inclr_loaded.partial_fit(X_split[1], y_split[1], queue=queue)
assert inclr._need_to_finalize == True
assert inclr_loaded._need_to_finalize == True

dump = pickle.dumps(inclr_loaded)
inclr_loaded = pickle.loads(dump)

Check warning on line 199 in onedal/linear_model/tests/test_incremental_linear_regression.py

View check run for this annotation

codefactor.io / CodeFactor

onedal/linear_model/tests/test_incremental_linear_regression.py#L199

Pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue. (B301)

assert inclr._need_to_finalize == True
assert inclr_loaded._need_to_finalize == False

inclr.finalize_fit()
inclr_loaded.finalize_fit()

# Check that finalized estimator can be serialized.
dump = pickle.dumps(inclr_loaded)
inclr_loaded = pickle.loads(dump)

Check warning on line 209 in onedal/linear_model/tests/test_incremental_linear_regression.py

View check run for this annotation

codefactor.io / CodeFactor

onedal/linear_model/tests/test_incremental_linear_regression.py#L209

Pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue. (B301)

assert_allclose(inclr.coef_, inclr_loaded.coef_, atol=1e-6)
assert_allclose(inclr.intercept_, inclr_loaded.intercept_, atol=1e-6)
Loading
Loading