diff --git a/python/pyspark/ml/tests/connect/test_parity_feature.py b/python/pyspark/ml/tests/connect/test_parity_feature.py index d75146be41cf3..baa3e6e7e0df9 100644 --- a/python/pyspark/ml/tests/connect/test_parity_feature.py +++ b/python/pyspark/ml/tests/connect/test_parity_feature.py @@ -30,6 +30,10 @@ def test_count_vectorizer_from_vocab(self): def test_string_indexer_from_labels(self): super().test_string_indexer_from_labels() + @unittest.skip("Need to support.") + def test_stop_words_lengague_selection(self): + super().test_stop_words_lengague_selection() + if __name__ == "__main__": from pyspark.ml.tests.connect.test_parity_feature import * # noqa: F401 diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py index edfda54ca83c4..f42f89b3014f9 100644 --- a/python/pyspark/ml/tests/test_feature.py +++ b/python/pyspark/ml/tests/test_feature.py @@ -896,13 +896,6 @@ def test_stop_words_remover_II(self): self.assertEqual(stopWordRemover.getStopWords(), stopwords) transformedDF = stopWordRemover.transform(dataset) self.assertEqual(transformedDF.head().output, ["a"]) - # with language selection - stopwords = StopWordsRemover.loadDefaultStopWords("turkish") - dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])]) - stopWordRemover.setStopWords(stopwords) - self.assertEqual(stopWordRemover.getStopWords(), stopwords) - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, []) # with locale stopwords = ["BELKÄ°"] dataset = self.spark.createDataFrame([Row(input=["belki"])]) @@ -911,6 +904,15 @@ def test_stop_words_remover_II(self): transformedDF = stopWordRemover.transform(dataset) self.assertEqual(transformedDF.head().output, []) + def test_stop_words_language_selection(self): + stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") + stopwords = StopWordsRemover.loadDefaultStopWords("turkish") + dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])]) + stopWordRemover.setStopWords(stopwords) + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, []) + def test_binarizer(self): b0 = Binarizer() self.assertListEqual(