Skip to content

Commit

Permalink
[SPARK-51078][PYTHOM][ML][TESTS][FOLLOW-UP] Separate language selecti…
Browse files Browse the repository at this point in the history
…on test for StopWordsRemover

### What changes were proposed in this pull request?

This PR is a followup of apache#49789 that separate language selection test for connect ML and skip it.

### Why are the changes needed?

To fix the build in https://github.com/apache/spark/actions/runs/13142915838/job/36674121214

### Does this PR introduce _any_ user-facing change?

No, test-only.

### How was this patch tested?

CI in this PR should test the separation. For Connect, i will monitor the scheduled build.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#49798 from HyukjinKwon/SPARK-510782.

Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
HyukjinKwon committed Feb 5, 2025
1 parent 1fc9d7d commit 198435e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
4 changes: 4 additions & 0 deletions python/pyspark/ml/tests/connect/test_parity_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def test_count_vectorizer_from_vocab(self):
def test_string_indexer_from_labels(self):
super().test_string_indexer_from_labels()

@unittest.skip("Need to support.")
def test_stop_words_lengague_selection(self):
super().test_stop_words_lengague_selection()


if __name__ == "__main__":
from pyspark.ml.tests.connect.test_parity_feature import * # noqa: F401
Expand Down
16 changes: 9 additions & 7 deletions python/pyspark/ml/tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,13 +896,6 @@ def test_stop_words_remover_II(self):
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, ["a"])
# with language selection
stopwords = StopWordsRemover.loadDefaultStopWords("turkish")
dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])])
stopWordRemover.setStopWords(stopwords)
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, [])
# with locale
stopwords = ["BELKİ"]
dataset = self.spark.createDataFrame([Row(input=["belki"])])
Expand All @@ -911,6 +904,15 @@ def test_stop_words_remover_II(self):
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, [])

def test_stop_words_language_selection(self):
stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
stopwords = StopWordsRemover.loadDefaultStopWords("turkish")
dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])])
stopWordRemover.setStopWords(stopwords)
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, [])

def test_binarizer(self):
b0 = Binarizer()
self.assertListEqual(
Expand Down

0 comments on commit 198435e

Please sign in to comment.