diff --git a/.gitignore b/.gitignore index 51d5539..671cbae 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,3 @@ /venv/ /models/ .idea -# TODO: fix it -/tests/data/fasttext.model.trainables.vectors_ngrams_lockf.npy -/tests/data/fasttext.model.wv.vectors_ngrams.npy diff --git a/tests/test_train.py b/tests/test_train.py index 1632568..9d11c4c 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -53,14 +53,13 @@ criterion=criterion, optimizer=optimizer, device=device, - n_epoch=10, + n_epoch=5, verbose=False, ) class TestTrain(unittest.TestCase): - # TODO: fix it - not always True def test_val_metrics(self): val_metrics = validate_loop( @@ -72,7 +71,7 @@ def test_val_metrics(self): ) for metric_name, metric_list in val_metrics.items(): - if metric_name.startswith('f1'): + if not metric_name.startswith('loss'): self.assertTrue(np.mean(metric_list) == 1.0)