78 classification (#80)
* preprocess : minimal length of documents kept + tests

* vectorizer + test

* classification's methods + tests

* xgboost interpretable dimensions

* adding xgboost for test workflow

* classification, fit and score test modification
aberanger authored Mar 28, 2024
1 parent 0a1bd74 commit 47e9c2b
Showing 5 changed files with 164 additions and 5 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -11,3 +11,4 @@ setuptools==66.0.0
spacy==3.5.3
tqdm==4.65.0
tabulate==0.9.0
xgboost==2.0.3
109 changes: 109 additions & 0 deletions sinr/text/evaluate.py
@@ -3,11 +3,13 @@
import scipy
from scipy import stats
from sklearn.datasets._base import Bunch
import sklearn.metrics as metrics
import pandas as pd
import urllib.request
import os
from tqdm.auto import tqdm
import time
import xgboost as xgb

def fetch_data_MEN():
    """Fetch MEN dataset for testing relatedness similarity
@@ -283,3 +285,110 @@ def dist_ratio_dim(sinr_vec, dim, union=None, prctbot=50, prcttop=10, nbtopk=5,
print("dimension",dim,"inter nulle", topks)
return 0
return intra / inter

def vectorizer(sinr_vec, X, y=[]):
    """Vectorize preprocessed documents into SINr embeddings.

    :param sinr_vec: SINrVectors object
    :type sinr_vec: SINrVectors
    :param X: preprocessed documents, a list of documents containing words
    :type X: list(list(str))
    :param y: documents labels
    :type y: list or numpy.ndarray
    :returns: list of document vectors, plus the labels of the documents kept when `y` is given
    """

    if len(y) > 0 and len(X) != len(y):
        raise ValueError("X and y must be the same size")

    indexes = set()
    vectors = list()

    # Document to embedding: average the vectors of its in-vocabulary tokens
    for i, doc in enumerate(X):
        doc_vec = [sinr_vec._get_vector(sinr_vec._get_index(token)) for token in doc if token in sinr_vec.vocab]
        if len(doc_vec) == 0:
            indexes.add(i)
        else:
            vectors.append(np.mean(doc_vec, axis=0))

    # Drop the labels of empty documents and of documents
    # containing only out-of-vocabulary tokens
    if len(y) > 0:
        y = np.delete(y, list(indexes))
        y = list(map(int, y))

    return vectors, y
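
A minimal usage sketch (not part of the diff; it assumes a trained SINr model named 'oanc' is available locally, as in the tests below):

import sinr.graph_embeddings as ge
from sinr.text.evaluate import vectorizer

# Load a previously trained SINr model (the 'oanc' name follows the tests below)
sinr_vec = ge.SINrVectors('oanc')
sinr_vec.load()

# Tokenized documents; a document whose tokens are all out of vocabulary is dropped
X = [['friend', 'family', 'happy'], ['zzzunknowntoken']]
y = [1, 0]

vectors, labels = vectorizer(sinr_vec, X, y=y)
# vectors and labels stay aligned: the second document and its label were removed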

def clf_fit(X_train, y_train, clf=None):
    """Fit a classification model according to the given training data.

    :param X_train: training data
    :type X_train: list of vectors
    :param y_train: labels
    :type y_train: numpy.ndarray
    :param clf: classifier, default: a fresh xgboost.XGBClassifier
    :type clf: classifier (e.g. xgboost.XGBClassifier, sklearn.svm.SVC)
    :returns: fitted classifier
    :rtype: classifier
    """
    if clf is None:
        clf = xgb.XGBClassifier()  # instantiated per call to avoid sharing a mutable default
    clf.fit(X_train, y_train)
    return clf
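
clf_fit defaults to XGBoost but accepts any estimator with a scikit-learn-style fit; a sketch swapping in an SVM (assuming X_train and y_train were produced by vectorizer as above):

from sklearn.svm import SVC
from sinr.text.evaluate import clf_fit

# Any estimator exposing fit/predict can replace the XGBoost default
clf = clf_fit(X_train, y_train, clf=SVC(kernel='linear'))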

def clf_score(clf, X_test, y_test, scoring='accuracy', params=None):
    """Evaluate classification on given test data.

    :param clf: fitted classifier
    :type clf: classifier (e.g. xgboost.XGBClassifier, sklearn.svm.SVC)
    :param X_test: test data
    :type X_test: list of vectors
    :param y_test: labels
    :type y_test: numpy.ndarray
    :param scoring: name of a scikit-learn metric, resolved as `sklearn.metrics.<scoring>_score` (default='accuracy')
    :type scoring: str
    :param params: keyword arguments for the scoring function
    :type params: dict
    :returns: score
    :rtype: float
    """
    score = getattr(metrics, scoring + '_score')
    y_pred = clf.predict(X_test)
    return score(y_test, y_pred, **(params or {}))
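
Since the scorer is resolved as sklearn.metrics.&lt;scoring&gt;_score, any metric named that way works; a sketch using the F1 score, with its keyword arguments forwarded through params (test data vectorized as above):

from sinr.text.evaluate import clf_score

# Resolves to sklearn.metrics.f1_score; 'average' is passed through params
f1 = clf_score(clf, X_test, y_test, scoring='f1', params={'average': 'macro'})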

def clf_xgb_interpretability(sinr_vec, xgb_clf, interpreter, topk_dim=10, topk=5, importance_type='gain'):
    """Interpretability of the main dimensions used by an XGBoost classifier.

    :param sinr_vec: SINrVectors object from which the data were vectorized
    :type sinr_vec: SINrVectors
    :param xgb_clf: fitted XGBoost classifier
    :type xgb_clf: xgboost.XGBClassifier
    :param interpreter: whether 'stereotypes' or 'descriptors' are requested
    :type interpreter: str
    :param topk_dim: number of features requested among the main features used by the classifier (Default value = 10)
    :type topk_dim: int
    :param topk: `topk` value to consider on each dimension (Default value = 5)
    :type topk: int
    :param importance_type: 'weight': the number of times a feature is used to split the data across all trees,
                            'gain': the average gain across all splits the feature is used in,
                            'cover': the average coverage across all splits the feature is used in,
                            'total_gain': the total gain across all splits the feature is used in,
                            'total_cover': the total coverage across all splits the feature is used in
    :type importance_type: str
    :returns: interpreters of dimensions, importance of dimensions
    :rtype: list of set of object, list of tuple
    """

    features = xgb_clf.get_booster().get_score(importance_type=importance_type)
    features = dict(sorted(features.items(), key=lambda x: x[1], reverse=True))
    features_index = [int(f[1:]) for f in list(features.keys())[:topk_dim]]
    features_importance = list(features.items())[:topk_dim]

    if interpreter == 'descriptors':
        dim = [sinr_vec.get_dimension_descriptors_idx(index, topk=topk) for index in features_index]
    elif interpreter == 'stereotypes':
        dim = [sinr_vec.get_dimension_stereotypes_idx(index, topk=topk) for index in features_index]
    else:
        raise ValueError("interpreter must be 'descriptors' or 'stereotypes'")

    return dim, features_importance
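
A sketch of inspecting a fitted model's top dimensions (reusing sinr_vec and the vectorized training data from above; the classifier must be an XGBoost one, since get_booster is XGBoost-specific):

from sinr.text.evaluate import clf_fit, clf_xgb_interpretability

xgb_clf = clf_fit(X_train, y_train)  # default XGBoost classifier
# Top 10 dimensions by average gain, each described by its 5 strongest descriptors
dims, importance = clf_xgb_interpretability(sinr_vec, xgb_clf, 'descriptors', topk_dim=10, topk=5)
for (feature, gain), descriptors in zip(importance, dims):
    print(feature, gain, descriptors)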
8 changes: 5 additions & 3 deletions sinr/text/preprocess.py
@@ -117,7 +117,7 @@ def do_txt_to_vrt(self, separator='sentence'):
        if separator == 'sentence':
            input_txt = input_file.read().splitlines()  # Read INPUT_FILE
        else:
-           input_txt = input_file.read().split(separator)
+           input_txt = input_file.read().split(separator)[1:]  # The first tag yields an empty leading document
        logger.info(str(len(input_txt)) + " lines to preprocess")
        input_file.close()

@@ -158,7 +158,7 @@ def do_txt_to_vrt(self, separator='sentence'):
        corpus_opened.close()
        logger.info(f"VRT-style file written in {self.corpus_output.absolute()}")

-def extract_text(corpus_path, exceptions_path=None, lemmatize=True, stop_words=False, lower_words=True, number=False, punct=False, exclude_pos=[], en="chunking", min_freq=50, alpha=True, exclude_en=[], min_length_word=3):
+def extract_text(corpus_path, exceptions_path=None, lemmatize=True, stop_words=False, lower_words=True, number=False, punct=False, exclude_pos=[], en="chunking", min_freq=50, alpha=True, exclude_en=[], min_length_word=3, min_length_doc=2):
"""Extracts the text from a VRT corpus file.
:param corpus_path: str
@@ -174,6 +174,8 @@ def extract_text(corpus_path, exceptions_path=None, lemmatize=True, stop_words=F
    :param exclude_en: list (Default value = [])
    :param lower_words: (Default value = True)
    :param min_length_word: (Default value = 3)
    :param min_length_doc: minimum number of tokens for a document (or sentence) to be kept (Default value = 2)
    :type min_length_doc: int
    :returns: text (list(list(str))): A list of documents containing words
    """
@@ -200,7 +202,7 @@ def extract_text(corpus_path, exceptions_path=None, lemmatize=True, stop_words=F
if line.startswith("<s>"):
document = []
elif line.startswith("</s>"):
if len(document) > 2:
if len(document) > min_length_doc:
out.append(document)
elif len(pattern.findall(line)) > 0:
pass
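
An end-to-end sketch of the two changed entry points (file names are hypothetical; the ppcs alias, the '##D' separator, and the VRT naming follow the tests below):

import sinr.text.preprocess as ppcs

corpus = ppcs.Corpus(ppcs.Corpus.REGISTER_WEB, ppcs.Corpus.LANGUAGE_EN, './my_corpus.txt')
vrt_maker = ppcs.VRTMaker(corpus, ".", n_jobs=8, spacy_size='sm')
vrt_maker.do_txt_to_vrt(separator='##D')  # the empty leading document is now skipped

# Keep only documents longer than 3 tokens (strictly greater, as in the code above)
docs = ppcs.extract_text('./my_corpus.vrt', min_freq=1, min_length_doc=3)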
24 changes: 23 additions & 1 deletion tests/test_preprocess.py
@@ -11,6 +11,7 @@
from sklearn.datasets import fetch_20newsgroups



class TestSinr_embeddings(unittest.TestCase):
    """Tests for `graph_embeddings` package."""

@@ -19,6 +20,8 @@ def setUp(self):

        txt_path = './ppcs_test.txt'
        vrt_path = './ppcs_test.vrt'
        txt_empty_docs_path = './ppcs_test_empty_docs.txt'
        vrt_empty_docs_path = './ppcs_test_empty_docs.vrt'
        doc_separator = '##D'
        s0 = doc_separator + " At 10 a.m, in the heart of New York City, John Smith walked briskly towards the Empire State Building. "
        text = (s0 +
@@ -38,8 +41,13 @@ def setUp(self):
        with open(txt_path, 'w+') as file:
            file.write(text)
        with open(txt_empty_docs_path, 'w+') as file:
            file.write(doc_separator + ' ' + doc_separator + ' ' + doc_separator + ' ' + doc_separator + ' ')
        self.txt_path = txt_path
        self.vrt_path = vrt_path
        self.txt_empty_docs_path = txt_empty_docs_path
        self.vrt_empty_docs_path = vrt_empty_docs_path
        self.n_doc = 4
        self.n_sent = 10
        self.doc_separator = doc_separator
@@ -48,7 +56,11 @@ def setUp(self):
    def tearDown(self):
        """Tear down test fixtures, if any."""
        os.remove(self.txt_path)
        os.remove(self.txt_empty_docs_path)
        # Each test produces only one of the two VRT files; remove whichever exists
        for vrt in (self.vrt_path, self.vrt_empty_docs_path):
            if os.path.isfile(vrt):
                os.remove(vrt)

    def test_doc_separator(self):
        """Test that the preprocessed data has the right number of documents"""
@@ -98,6 +110,16 @@ def test_preprocessed(self):
                ok.append(token.text.lower() == s[ind])

        self.assertTrue(False not in ok)

    def test_preprocessing_empty_docs(self):
        """Test that with min_length_doc = -1, documents of all sizes are kept"""
        vrt_maker = ppcs.VRTMaker(ppcs.Corpus(ppcs.Corpus.REGISTER_WEB,
                                              ppcs.Corpus.LANGUAGE_EN,
                                              self.txt_empty_docs_path),
                                  ".", n_jobs=8, spacy_size='sm')
        vrt_maker.do_txt_to_vrt(separator=self.doc_separator)
        docs = ppcs.extract_text(self.vrt_empty_docs_path, min_freq=1, min_length_doc=-1)
        self.assertEqual(len(docs), self.n_doc)


if __name__ == '__main__':
    unittest.main()
27 changes: 26 additions & 1 deletion tests/test_sinr_evaluate.py
@@ -7,7 +7,7 @@
import unittest

import sinr.graph_embeddings as ge
-from sinr.text.evaluate import fetch_data_MEN, fetch_data_WS353, eval_similarity, similarity_MEN_WS353_SCWS
+from sinr.text.evaluate import fetch_data_MEN, fetch_data_WS353, eval_similarity, similarity_MEN_WS353_SCWS, vectorizer, clf_fit, clf_score
import urllib.request
import os

@@ -30,6 +30,17 @@ def setUp(self):
        vectors = ge.SINrVectors('oanc')
        vectors.load()
        self.vectors = vectors

        # data for classification
        X_train = [['goodbye', 'please', 'love'], [], ['no', 'yes', 'friend', 'family', 'happy'], ['a', 'the'], ['beautiful', 'small', 'a']]
        y_train = [0, 0, 1, 0, 1]
        X_test = [['goodbye', 'family', 'friend'], ['beautiful', 'small'], ['please', 'happy', 'love']]
        y_test = [0, 1, 0]

        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test

    def tearDown(self):
        """Tear down test fixtures, if any."""
@@ -43,6 +54,20 @@ def test_similarity_MEN_WS353_SCWS(self):
self.assertGreater(round(res["MEN"],2), 0.38)
self.assertGreater(round(res["WS353"],2), 0.40)
self.assertGreater(round(res["SCWS"],2), 0.38)

    def test_vectorize(self):
        X, y = vectorizer(self.vectors, self.X_train, y=self.y_train)
        self.assertEqual(len(X), len(y))

    def test_clf_fit_and_score(self):
        X_train, y_train = vectorizer(self.vectors, self.X_train, y=self.y_train)
        X_test, y_test = vectorizer(self.vectors, self.X_test, y=self.y_test)

        clf = clf_fit(X_train, y_train)
        score = clf_score(clf, X_test, y_test)

        self.assertTrue(0 <= score <= 1)


if __name__ == '__main__':
    unittest.main()
