-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_utils.py
63 lines (51 loc) · 2.17 KB
/
test_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import pytest
from utils import split_train_dev_test,read_digits,preprocess_data,tune_hparams
import os
from joblib import dump,load
def inc(x):
    """Return ``x`` increased by one."""
    successor = x + 1
    return successor
def test_inc():
    """Sanity check: ``inc`` maps 4 to 5."""
    result = inc(4)
    assert result == 5
def create_dummy_hyperparamete():
    """Build the hyperparameter grid shared by the tests in this module.

    Returns:
        A list of ``{'gamma': ..., 'C': ...}`` dicts, one for every pairing
        of the seven gamma values with the five C values (35 entries). The
        list is ordered with gamma varying slowest and C varying fastest.
    """
    gamma_values = [0.001, 0.01, 0.1, 1, 10, 100, 1000]
    c_values = [0.1, 1, 2, 5, 10]

    param_grid = []
    for gamma in gamma_values:
        for c_value in c_values:
            param_grid.append({'gamma': gamma, 'C': c_value})
    return param_grid
def create_dummy_data():
    """Return a small preprocessed train/dev sample of the ``read_digits`` data.

    Returns:
        Tuple ``(X_train, y_train, X_dev, y_dev)``. The train set is the
        first 100 samples returned by ``read_digits`` and the dev set is the
        first 50; both feature arrays are passed through ``preprocess_data``
        while the label arrays are returned as sliced.
    """
    images, labels = read_digits()

    # NOTE(review): the dev slice [:50] is a subset of the train slice [:100].
    # That is harmless for a smoke test, but confirm the overlap is intended.
    train_images, train_labels = images[:100, :, :], labels[:100]
    dev_images, dev_labels = images[:50, :, :], labels[:50]

    return (
        preprocess_data(train_images),
        train_labels,
        preprocess_data(dev_images),
        dev_labels,
    )
def test_hparam_count():
    """The dummy hyperparameter grid must contain exactly 35 combinations."""
    param_grid = create_dummy_hyperparamete()
    combination_count = len(param_grid)
    assert combination_count == 35
def test_mode_saving():
    """Tuning on the dummy data must leave a file at the reported model path."""
    X_train, y_train, X_dev, y_dev = create_dummy_data()
    param_grid = create_dummy_hyperparamete()

    tuning_result = tune_hparams(
        X_train, y_train, X_dev, y_dev, param_grid, 'svm'
    )
    # Only the middle element of the 3-item result is needed here; it is
    # treated as the filesystem path of the best model.
    _, best_model_path, _ = tuning_result
    assert os.path.exists(best_model_path)
def test_data_splitting():
    """Check that ``split_train_dev_test`` produces the requested split sizes.

    Uses the first 100 samples from ``read_digits`` with a 0.1 test fraction
    and a 0.6 dev fraction; whatever remains is expected to be the train set.
    """
    X, y = read_digits()
    X = X[:100, :, :]
    y = y[:100]

    test_size = 0.1
    dev_size = 0.6
    train_size = 1 - (dev_size + test_size)

    X_train, X_test, X_dev, y_train, y_test, y_dev = split_train_dev_test(
        X, y, test_size=test_size, dev_size=dev_size
    )

    total = len(X)
    # One assert per split so that a failure reports which split has the
    # wrong size, instead of a single opaque compound condition.
    assert len(X_train) == int(train_size * total)
    assert len(X_test) == int(test_size * total)
    assert len(X_dev) == int(dev_size * total)
def test_model_is_lr():
    """Each saved per-solver model must have class name ``LogisticRegression``."""
    for solver_name in ('lbfgs', 'liblinear', 'newton-cg', 'sag', 'saga'):
        model_path = "./models/m23csa018_lr_{}.joblib".format(solver_name)
        model = load(model_path)
        # The check is by class name, so no estimator import is needed here.
        assert model.__class__.__name__ == 'LogisticRegression'
def test_model_is_lr_solver():
    """Each saved model's ``solver`` attribute must match the name in its file."""
    for solver_name in ('lbfgs', 'liblinear', 'newton-cg', 'sag', 'saga'):
        model_path = "./models/m23csa018_lr_{}.joblib".format(solver_name)
        model = load(model_path)
        # The solver recorded on the model must equal the one in the filename.
        assert model.solver == solver_name