From d6ae7d4637d23c57c4eeab79d1177216f380ec9c Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" <joseph@databricks.com> Date: Fri, 15 Apr 2016 11:50:21 -0700 Subject: [PATCH] [SPARK-14665][ML][PYTHON] Fixed bug with StopWordsRemover default stopwords ## What changes were proposed in this pull request? The default stopwords of the Python StopWordsRemover were stored as a Java object instead of a Python list. They are now converted to a Python list of strings. ## How was this patch tested? Added a unit test that failed before the fix. Author: Joseph K. Bradley <joseph@databricks.com> Closes #12422 from jkbradley/pyspark-stopwords. --- python/pyspark/ml/feature.py | 2 +- python/pyspark/ml/tests.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 809a513316..0d8ef1297f 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1765,7 +1765,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover", self.uid) stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords - defaultStopWords = stopWordsObj.English() + defaultStopWords = list(stopWordsObj.English()) self._setDefault(stopWords=defaultStopWords, caseSensitive=False) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 86c0254a2b..85ad949c5a 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -22,6 +22,7 @@ import array import sys if sys.version > '3': xrange = range + basestring = str try: import xmlrunner @@ -398,6 +399,8 @@ class FeatureTests(PySparkTestCase): self.assertEqual(stopWordRemover.getInputCol(), "input") transformedDF = stopWordRemover.transform(dataset) self.assertEqual(transformedDF.head().output, ["panda"]) + self.assertEqual(type(stopWordRemover.getStopWords()), list) + self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], basestring)) # Custom stopwords = ["panda"] 
stopWordRemover.setStopWords(stopwords) -- GitLab