From d6ae7d4637d23c57c4eeab79d1177216f380ec9c Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" <joseph@databricks.com>
Date: Fri, 15 Apr 2016 11:50:21 -0700
Subject: [PATCH] [SPARK-14665][ML][PYTHON] Fixed bug with StopWordsRemover
 default stopwords

## What changes were proposed in this pull request?

The default stopwords were a Java object.  They are no longer.

## How was this patch tested?

Unit test which failed before the fix

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #12422 from jkbradley/pyspark-stopwords.
---
 python/pyspark/ml/feature.py | 2 +-
 python/pyspark/ml/tests.py   | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 809a513316..0d8ef1297f 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -1765,7 +1765,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
                                             self.uid)
         stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords
-        defaultStopWords = stopWordsObj.English()
+        defaultStopWords = list(stopWordsObj.English())
         self._setDefault(stopWords=defaultStopWords, caseSensitive=False)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 86c0254a2b..85ad949c5a 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -22,6 +22,7 @@ import array
 import sys
 if sys.version > '3':
     xrange = range
+    basestring = str
 
 try:
     import xmlrunner
@@ -398,6 +399,8 @@ class FeatureTests(PySparkTestCase):
         self.assertEqual(stopWordRemover.getInputCol(), "input")
         transformedDF = stopWordRemover.transform(dataset)
         self.assertEqual(transformedDF.head().output, ["panda"])
+        self.assertEqual(type(stopWordRemover.getStopWords()), list)
+        self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], basestring))
         # Custom
         stopwords = ["panda"]
         stopWordRemover.setStopWords(stopwords)
-- 
GitLab