From 31c74fec24ae3bc8b9eb4ecd90896de459c3cc22 Mon Sep 17 00:00:00 2001
From: Xin Ren <iamshrek@126.com>
Date: Fri, 8 Sep 2017 12:09:00 -0700
Subject: [PATCH] [SPARK-19866][ML][PYSPARK] Add local version of Word2Vec
 findSynonyms for spark.ml: Python API

https://issues.apache.org/jira/browse/SPARK-19866

## What changes were proposed in this pull request?

Add Python API for findSynonymsArray matching Scala API.

## How was this patch tested?

Manual test
`./python/run-tests --python-executables=python2.7 --modules=pyspark-ml`

Author: Xin Ren <iamshrek@126.com>
Author: Xin Ren <renxin.ubc@gmail.com>
Author: Xin Ren <keypointt@users.noreply.github.com>

Closes #17451 from keypointt/SPARK-19866.
---
 .../org/apache/spark/ml/feature/Word2Vec.scala    |  2 +-
 python/pyspark/ml/feature.py                      | 15 +++++++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index d4c8e4b361..f6095e26f4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -229,7 +229,7 @@ class Word2VecModel private[ml] (
    * Find "num" number of words closest in similarity to the given word, not
    * including the word itself.
    * @return a dataframe with columns "word" and "similarity" of the word and the cosine
-   * similarities between the synonyms and the given word vector.
+   * similarities between the synonyms and the given word.
    */
   @Since("1.5.0")
   def findSynonyms(word: String, num: Int): DataFrame = {
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 050537b811..232ae3ef41 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2751,6 +2751,8 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
     |   c|[-0.3794820010662...|
     +----+--------------------+
     ...
+    >>> model.findSynonymsArray("a", 2)
+    [(u'b', 0.25053444504737854), (u'c', -0.6980510950088501)]
     >>> from pyspark.sql.functions import format_number as fmt
     >>> model.findSynonyms("a", 2).select("word", fmt("similarity", 5).alias("similarity")).show()
     +----+----------+
@@ -2927,6 +2929,19 @@ class Word2VecModel(JavaModel, JavaMLReadable, JavaMLWritable):
             word = _convert_to_vector(word)
         return self._call_java("findSynonyms", word, num)
 
+    @since("2.3.0")
+    def findSynonymsArray(self, word, num):
+        """
+        Find "num" number of words closest in similarity to "word".
+        word can be a string or vector representation.
+        Returns an array with two fields word and similarity (which
+        gives the cosine similarity).
+        """
+        if not isinstance(word, basestring):
+            word = _convert_to_vector(word)
+        tuples = self._java_obj.findSynonymsArray(word, num)
+        return list(map(lambda st: (st._1(), st._2()), list(tuples)))
+
 
 @inherit_doc
 class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
-- 
GitLab