diff --git a/sinr/text/evaluate.py b/sinr/text/evaluate.py index bd28c90..80edad7 100644 --- a/sinr/text/evaluate.py +++ b/sinr/text/evaluate.py @@ -3,11 +3,13 @@ import scipy from scipy import stats from sklearn.datasets._base import Bunch +import sklearn.metrics as metrics import pandas as pd import urllib.request import os from tqdm.auto import tqdm import time +import xgboost as xgb def fetch_data_MEN(): """Fetch MEN dataset for testing relatedness similarity @@ -293,6 +295,8 @@ def vectorizer(sinr_vec, X, y=[]): :type X: text (list(list(str))): A list of documents containing words :param y: documents labels :type y: numpy.ndarray + + :returns: list of vectors """ if len(y) > 0 and len(X) != len(y): @@ -317,3 +321,38 @@ def vectorizer(sinr_vec, X, y=[]): y = list(map(int,y)) return vectors, y + +def clf_fit(X_train, y_train, clf=xgb.XGBClassifier()): + """Fit a classification model according to the given training data. + :param X_train: training data + :type X_train: list of vectors + :param y_train: labels + :type y_train: numpy.ndarray + :param clf: classifier + :type clf: classifier (ex.: xgboost.XGBClassifier, sklearn.svm.SVC) + + :returns: Fitted classifier + :rtype: classifier + """ + clf.fit(X_train, y_train) + return clf + +def clf_score(clf, X_test, y_test, scoring='accuracy', params={}): + """Evaluate classification on given test data. + :param clf: classifier + :type clf: classifier (ex.: xgboost.XGBClassifier, sklearn.svm.SVC) + :param X_test: test data + :type X_test: list of vectors + :param y_test: labels + :type y_test: numpy.ndarray + :param scoring: scikit-learn scorer object, default='accuracy' + :type scoring: str + :param params: parameters for the scorer object + :type params: dictionary + + :returns: Score + :rtype: float + """ + score = getattr(metrics, scoring+'_score') + y_pred = clf.predict(X_test) + return score(y_test, y_pred, **params) diff --git a/tests/test_sinr_evaluate.py b/tests/test_sinr_evaluate.py index f3a4464..0b0e743 100644 --- a/tests/test_sinr_evaluate.py +++ b/tests/test_sinr_evaluate.py @@ -7,7 +7,7 @@ import unittest import sinr.graph_embeddings as ge -from sinr.text.evaluate import fetch_data_MEN, fetch_data_WS353, eval_similarity, similarity_MEN_WS353_SCWS, vectorizer +from sinr.text.evaluate import fetch_data_MEN, fetch_data_WS353, eval_similarity, similarity_MEN_WS353_SCWS, vectorizer, clf_fit, clf_score import urllib.request import os @@ -34,9 +34,13 @@ def setUp(self): # datas for classification X_train = [['goodbye', 'please', 'love'],[],['no', 'yes', 'friend', 'family', 'happy'],['a', 'the'],['beautiful','small','a']] y_train = [0,0,1,0,1] + X_test = [['goodbye', 'family', 'friend'],['beautiful', 'small'],['please','happy','love']] + y_test = [0,1,0] self.X_train = X_train self.y_train = y_train + self.X_test = X_test + self.y_test = y_test def tearDown(self): """Tear down test fixtures, if any.""" @@ -54,6 +58,16 @@ def test_similarity_MEN_WS353_SCWS(self): def test_vectorize(self): X, y = vectorizer(self.vectors, self.X_train, y=self.y_train) self.assertTrue(len(X) == len(y)) + + def test_clf_fit_and_score(self): + X_train, y_train = vectorizer(self.vectors, self.X_train, y=self.y_train) + X_test, y_test = vectorizer(self.vectors, self.X_test, y=self.y_test) + + clf = clf_fit(X_train, y_train) + score = clf_score(clf, X_test, y_test) + + self.assertGreater(round(score,2), 0.66) + if __name__ == '__main__': unittest.main() \ No newline at end of file