Skip to content

Commit

Permalink
After paper submission, added setup.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mfeurer committed Jul 1, 2014
1 parent ff56d58 commit d4334e3
Show file tree
Hide file tree
Showing 14 changed files with 144 additions and 20 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
~
*.pyc
*~
.idea
dist/
AutoSklearn.egg-info

31 changes: 29 additions & 2 deletions AutoSklearn/autosklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self,
assert classifier is None
assert preprocessor is None
classifier = parameters.get("classifier")
preprocessor = parameters.get("preprocessor")
preprocessor = parameters.get("preprocessing")
if preprocessor == "None":
preprocessor = None

Expand All @@ -82,10 +82,35 @@ def __init__(self,

# TODO: make sure that there are no duplicate classifiers
self._available_classifiers = classification_components._classifiers
classifier_parameters = set()
for _classifier in self._available_classifiers:
accepted_hyperparameter_names = self._available_classifiers[_classifier] \
.get_all_accepted_hyperparameter_names()
name = self._available_classifiers[_classifier].get_hyperparameter_search_space()['name']
for key in accepted_hyperparameter_names:
classifier_parameters.add("%s:%s" % (name, key))

self._available_preprocessors = preprocessing_components._preprocessors
preprocessor_parameters = set()
for _preprocessor in self._available_preprocessors:
accepted_hyperparameter_names = self._available_preprocessors[_preprocessor] \
.get_all_accepted_hyperparameter_names()
name = self._available_preprocessors[_preprocessor].get_hyperparameter_search_space()['name']
for key in accepted_hyperparameter_names:
preprocessor_parameters.add("%s:%s" % (name, key))

for parameter in self.parameters:
if parameter not in classifier_parameters and \
parameter not in preprocessor_parameters and \
parameter not in ("preprocessing", "classifier", "name"):
print "Classifier parameters %s" % str(classifier_parameters)
print "Preprocessing parameters %s" % str(preprocessor_parameters)
raise ValueError("Parameter %s is unknown." % parameter)

if random_state is None:
random_state = check_random_state(1)
self.random_state = check_random_state(1)
else:
self.random_state = check_random_state(random_state)

self._estimator_class = self._available_classifiers.get(classifier)
if classifier is not None and self._estimator_class is None:
Expand All @@ -99,6 +124,8 @@ def __init__(self,
"of preprocessors found on this system: %s" %
(preprocessor, self._available_preprocessors))



def fit(self, X, Y):
# TODO: perform input validation
# TODO: look if X.shape[0] == y.shape[0]
Expand Down
7 changes: 6 additions & 1 deletion AutoSklearn/components/classification/liblinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from ..classification_base import AutoSklearnClassificationAlgorithm

class LibLinear_SVC(AutoSklearnClassificationAlgorithm):
# TODO: maybe add dual and crammer-singer?
# Liblinear is not deterministic as it uses a RNG inside
# TODO: maybe add dual and crammer-singer?
def __init__(self, penalty="l2", loss="l2", C=1.0, LOG2_C=None, random_state=None):
self.penalty = penalty
self.loss = loss
Expand Down Expand Up @@ -52,5 +53,9 @@ def get_hyperparameter_search_space():
return {"name": "liblinear", "penalty_and_loss": penalty_and_loss,
"LOG2_C": LOG2_C}

@staticmethod
def get_all_accepted_hyperparameter_names():
return (["LOG2_C", "C", "penalty", "loss"])

def __str__(self):
return "AutoSklearn Liblinear Classifier"
9 changes: 7 additions & 2 deletions AutoSklearn/components/classification/libsvm_svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,13 @@ def handles_non_binary_classes(self):
@staticmethod
def get_hyperparameter_search_space():
LOG2_C = hp_uniform("LOG2_C", -5, 15)
LOG2_gamma = hp_uniform("LOG2_gamma", -15, 5)
return {"name": "libsmv_scv", "LOG2_C": LOG2_C, "LOG2_gamma": LOG2_gamma}
LOG2_gamma = hp_uniform("LOG2_gamma", -15, 3)
return {"name": "libsvm_svc", "LOG2_C": LOG2_C, "LOG2_gamma":
LOG2_gamma}

@staticmethod
def get_all_accepted_hyperparameter_names():
return (["LOG2_C", "C", "LOG2_gamma", "gamma"])

def __str__(self):
return "AutoSklearn LibSVM Classifier"
7 changes: 6 additions & 1 deletion AutoSklearn/components/classification/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def fit(self, X, Y):
self.max_depth = int(self.max_depth)
self.min_samples_split = int(self.min_samples_split)
self.min_samples_leaf = int(self.min_samples_leaf)
if self.max_features not in ("sqrt", ("log2")):
if self.max_features not in ("sqrt", "log2", "auto"):
self.max_features = float(self.max_features)

self.estimator = sklearn.ensemble.RandomForestClassifier(
Expand Down Expand Up @@ -72,5 +72,10 @@ def get_hyperparameter_search_space():
min_samples_split, "min_samples_leaf": min_samples_leaf,
"bootstrap": bootstrap}

@staticmethod
def get_all_accepted_hyperparameter_names():
return (["n_estimators", "criterion", "max_features",
"min_samples_split", "min_samples_leaf", "bootstrap"])

def __str__(self):
return "AutoSklearn LibSVM Classifier"
3 changes: 3 additions & 0 deletions AutoSklearn/components/classification_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ def handles_non_binary_classes(self):
def get_hyperparameter_search_space(self):
raise NotImplementedError()

def get_all_accepted_hyperparameter_names(self):
raise NotImplementedError()

def fit(self, X, Y):
raise NotImplementedError()

Expand Down
12 changes: 10 additions & 2 deletions AutoSklearn/components/preprocessing/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@ def __init__(self, keep_variance=1.0, whiten=False, random_state=None):
self.whiten = whiten

def fit(self, X, Y):
# TODO: implement that keep_variance can be a percentage (in int)
self.preprocessor = sklearn.decomposition.PCA(whiten=self.whiten,
copy=True)
# num components is
# selected further down
# the code
self.preprocessor.fit(X, Y)

sum_ = 0.
Expand Down Expand Up @@ -42,10 +46,14 @@ def handles_non_binary_classes(self):

@staticmethod
def get_hyperparameter_search_space():
keep_variance = hp_uniform("n_components", 0.5, 1.0)
keep_variance = hp_uniform("keep_variance", 0.5, 1.0)
whiten = hp_choice("whiten", ["False", "True"])
return {"name": "pca", "keep_variance": keep_variance,
"whiten": whiten}

@staticmethod
def get_all_accepted_hyperparameter_names():
return (["keep_variance", "whiten"])

def __str__(self):
return "AutoSklearn Principle Component Analysis preprocessor."
return "AutoSklearn Principle Component Analysis preprocessor."
3 changes: 3 additions & 0 deletions AutoSklearn/components/preprocessor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ def handles_non_binary_classes(self):
def get_hyperparameter_search_space(self):
raise NotImplementedError()

def get_all_accepted_hyperparameter_names():
raise NotImplementedError()

def fit(self, X, Y):
raise NotImplementedError()

Expand Down
Empty file added CHANGES.md
Empty file.
Empty file added LICENSE.txt
Empty file.
Empty file added README.md
Empty file.
15 changes: 15 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import setuptools

setuptools.setup(name="AutoSklearn",
description="Scikit-Learn wrapper for automatic "
"hyperparameter configuration.",
version="0.1dev",
packages=setuptools.find_packages(),
install_requires=["scikit_learn==0.14.1"],
package_data={'': ['*.txt', '*.md']},
author="Matthias Feurer",
author_email="[email protected]",
license="BSD",
platforms=['Linux'],
classifiers=[]
url="github.com/mfeurer/autosklearn")
15 changes: 10 additions & 5 deletions tests/test_all_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,17 @@ def test_all_combinations(self):
#for n_components, whiten in itertools.product(pca_n_components):
#pca_whiten):
for n_components in pca_n_components:
pca.append({"pca:n_components": n_components,
pca.append({"pca:keep_variance": n_components,
#"pca:whiten": whiten,
"preprocessor": "pca"})
"preprocessing": "pca"})
print "Parameter configurations PCA", len(pca)

classifiers = [liblinear, libsvm_svc, random_forest]
preprocessors = [pca, [{"preprocessor": None}]]
preprocessors = [pca, [{"preprocessing": None}]]

for classifier, preprocessor in itertools.product(classifiers,
preprocessors):
print classifier[0]["classifier"], preprocessor[0]["preprocessor"]
print classifier[0]["classifier"], preprocessor[0]["preprocessing"]
for classifier_params, preprocessor_params in itertools.product(
classifier, preprocessor):
params = {}
Expand All @@ -110,7 +110,12 @@ def test_all_combinations(self):
for i, parameter_combination in enumerate(parameter_combinations):
auto = AutoSklearnClassifier(parameters=parameter_combination)
X_train, Y_train, X_test, Y_test = self.get_iris()
auto = auto.fit(X_train, Y_train)
try:
auto = auto.fit(X_train, Y_train)
except Exception as e:
print parameter_combination
print (parameter_combination['random_forest:max_features'] * X_train.shape[1])
raise e
predictions = auto.predict(X_test)
accuracy = sklearn.metrics.accuracy_score(Y_test, predictions)

Expand Down
56 changes: 50 additions & 6 deletions tests/test_autosklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import sklearn.datasets
import sklearn.decomposition
import sklearn.ensemble
import sklearn.svm

from AutoSklearn.autosklearn import AutoSklearnClassifier
from AutoSklearn.components.classification_base import AutoSklearnClassificationAlgorithm
Expand All @@ -17,6 +19,9 @@
from AutoSklearn.util import NoModelException

class TestAutoSKlearnClassifier(unittest.TestCase):
# TODO: test for both possible ways to initialize AutoSklearn
# parameters and other...

def get_iris(self):
iris = sklearn.datasets.load_iris()
X = iris.data
Expand Down Expand Up @@ -63,20 +68,41 @@ def test_init_unknown_classifier(self):
def test_init_parameters_as_dict_or_as_keywords(self):
pass

def test_fit_iris(self):
auto = AutoSklearnClassifier("liblinear", None)
def test_predict_iris(self):
auto = AutoSklearnClassifier(parameters={"classifier": "liblinear",
"preprocessing": None})
X_train, Y_train, X_test, Y_test = self.get_iris()
auto = auto.fit(X_train, Y_train)
predictions = auto.predict(X_test)
accuracy = sklearn.metrics.accuracy_score(Y_test, predictions)
self.assertIsInstance(auto, AutoSklearnClassifier)
self.assertIsInstance(auto._estimator, AutoSklearnClassificationAlgorithm)
self.assertIsInstance(auto._estimator.estimator, sklearn.svm.LinearSVC)
self.assertAlmostEqual(accuracy, 1.0)

def test_predict_iris(self):
auto = AutoSklearnClassifier("liblinear", None)
def test_predict_svm(self):
auto = AutoSklearnClassifier(parameters={"classifier": "libsvm_svc",
"preprocessing": None})
X_train, Y_train, X_test, Y_test = self.get_iris()
auto = auto.fit(X_train, Y_train)
predictions = auto.predict(X_test)
accuracy = sklearn.metrics.accuracy_score(Y_test, predictions)
self.assertAlmostEqual(accuracy, 1.0)
self.assertIsInstance(auto, AutoSklearnClassifier)
self.assertIsInstance(auto._estimator, AutoSklearnClassificationAlgorithm)
self.assertIsInstance(auto._estimator.estimator, sklearn.svm.SVC)
self.assertAlmostEqual(accuracy, 0.959999999999)

def test_predict_iris_rf(self):
auto = AutoSklearnClassifier(parameters={"classifier": "random_forest",
"preprocessing": None})
X_train, Y_train, X_test, Y_test = self.get_iris()
auto = auto.fit(X_train, Y_train)
predictions = auto.predict(X_test)
accuracy = sklearn.metrics.accuracy_score(Y_test, predictions)
self.assertIsInstance(auto, AutoSklearnClassifier)
self.assertIsInstance(auto._estimator, AutoSklearnClassificationAlgorithm)
self.assertIsInstance(auto._estimator.estimator, sklearn.ensemble.RandomForestClassifier)
self.assertAlmostEqual(accuracy, 0.959999999999)

def test_fit_with_preproc(self):
auto = AutoSklearnClassifier("liblinear", "pca")
Expand All @@ -102,11 +128,29 @@ def test_specify_hyperparameters(self):
"random_forest:max_features": 1.0})
X_train, Y_train, X_test, Y_test = self.get_iris()
auto = auto.fit(X_train, Y_train)
self.assertIsNotNone(auto._preprocessor)
self.assertIsNotNone(auto._preprocessor.preprocessor)
self.assertIsNotNone(auto._estimator)
self.assertIsNotNone(auto._estimator.estimator)
predictions = auto.predict(X_test)
accuracy = sklearn.metrics.accuracy_score(Y_test, predictions)
self.assertAlmostEqual(accuracy, 0.939999999)
self.assertAlmostEqual(accuracy, 0.92)
self.assertEqual(auto._estimator.estimator.n_estimators, 1)

def test_specify_unknown_hyperparameters(self):
self.assertRaisesRegexp(ValueError,
"Parameter random_forest:blablabla is unknown.",
AutoSklearnClassifier, random_state=1,
parameters={"classifier": "random_forest",
"preprocessing": "pca",
"random_forest:blablabla": 1})
self.assertRaisesRegexp(ValueError,
"Parameter pca:blablabla is unknown.",
AutoSklearnClassifier, random_state=1,
parameters={"classifier": "random_forest",
"preprocessing": "pca",
"pca:blablabla": 1})

def test_get_hyperparameter_search_space(self):
auto = AutoSklearnClassifier(None, None)
space = auto.get_hyperparameter_search_space()
Expand Down

0 comments on commit d4334e3

Please sign in to comment.