Skip to content
This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Commit

Permalink
improve coverage and support older numpy tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vene committed Dec 21, 2016
1 parent 93fefc6 commit f98f2b1
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 8 deletions.
27 changes: 20 additions & 7 deletions polylearn/tests/test_adagrad.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from nose.tools import assert_less_equal

import numpy as np
from numpy.testing import assert_array_almost_equal, assert_raises_regex
from numpy.testing import assert_array_almost_equal, assert_raises

try:
from numpy.testing import assert_raises_regex
has_assert_raises_regex = True
except ImportError:
has_assert_raises_regex = False

import scipy.sparse as sp

Expand Down Expand Up @@ -191,11 +197,18 @@ def test_predict_sensible_error():
fit_linear=False, fit_lower=None,
max_iter=3, random_state=0)
reg.fit(X, y)
assert_raises_regex(ValueError,
"Incompatible dimensions",
reg.predict,
X[:, :2])
reg.P_ = np.transpose(reg.P_, [1, 2, 0])
assert_raises_regex(ValueError, "wrong order", reg.predict, X)
if has_assert_raises_regex:
assert_raises_regex(ValueError,
"Incompatible dimensions",
reg.predict,
X[:, :2])
reg.P_ = np.transpose(reg.P_, [1, 2, 0])
assert_raises_regex(ValueError, "wrong order", reg.predict, X)
else:
# if assert_raises_regex is not available, use looser test
assert_raises(ValueError, reg.predict, X[:, :2])
reg.P_ = np.transpose(reg.P_, [1, 2, 0])
assert_raises(ValueError, "wrong order")



32 changes: 31 additions & 1 deletion polylearn/tests/test_factorization_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from nose.tools import assert_less_equal, assert_equal

import numpy as np
from numpy.testing import assert_array_almost_equal
from numpy.testing import assert_array_almost_equal, assert_raises

from sklearn.metrics import mean_squared_error
from sklearn.utils.testing import assert_warns_message
Expand Down Expand Up @@ -340,3 +340,33 @@ def check_warm_start(degree):
def test_warm_start():
yield check_warm_start, 2
yield check_warm_start, 3


def test_lambdas():
"""Check that +/-1 lambdas lead to better train error for even degree."""
y = _poly_predict(X, P, lams, kernel="anova", degree=2)

est = FactorizationMachineRegressor(degree=2, n_components=5,
fit_linear=False, fit_lower=None,
beta=0.1, random_state=0)
y_pred_ones = est.fit(X, y).predict(X)
err_ones = mean_squared_error(y, y_pred_ones)

est.set_params(init_lambdas='random_signs')
y_pred_signs = est.fit(X, y).predict(X)
err_signs = mean_squared_error(y, y_pred_signs)

assert_less_equal(err_signs, err_ones)


def test_unsupported_errors():
y = _poly_predict(X, P, lams, kernel="anova", degree=2)

est = FactorizationMachineRegressor(degree=10, n_components=5,
fit_linear=False, fit_lower=None,
beta=0.1, random_state=0)

assert_raises(NotImplementedError, est.fit, X, y)

est.set_params(solver='adagrad', init_lambdas='random_signs')
assert_raises(NotImplementedError, est.fit, X, y)

0 comments on commit f98f2b1

Please sign in to comment.