From 31c74fec24ae3bc8b9eb4ecd90896de459c3cc22 Mon Sep 17 00:00:00 2001 From: Xin Ren <iamshrek@126.com> Date: Fri, 8 Sep 2017 12:09:00 -0700 Subject: [PATCH] [SPARK-19866][ML][PYSPARK] Add local version of Word2Vec findSynonyms for spark.ml: Python API https://issues.apache.org/jira/browse/SPARK-19866 ## What changes were proposed in this pull request? Add Python API for findSynonymsArray matching Scala API. ## How was this patch tested? Manual test `./python/run-tests --python-executables=python2.7 --modules=pyspark-ml` Author: Xin Ren <iamshrek@126.com> Author: Xin Ren <renxin.ubc@gmail.com> Author: Xin Ren <keypointt@users.noreply.github.com> Closes #17451 from keypointt/SPARK-19866. --- .../org/apache/spark/ml/feature/Word2Vec.scala | 2 +- python/pyspark/ml/feature.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index d4c8e4b361..f6095e26f4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -229,7 +229,7 @@ class Word2VecModel private[ml] ( * Find "num" number of words closest in similarity to the given word, not * including the word itself. * @return a dataframe with columns "word" and "similarity" of the word and the cosine - * similarities between the synonyms and the given word vector. + * similarities between the synonyms and the given word. */ @Since("1.5.0") def findSynonyms(word: String, num: Int): DataFrame = { diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 050537b811..232ae3ef41 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2751,6 +2751,8 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has | c|[-0.3794820010662...| +----+--------------------+ ... + >>> model.findSynonymsArray("a", 2) + [(u'b', 0.25053444504737854), (u'c', -0.6980510950088501)] >>> from pyspark.sql.functions import format_number as fmt >>> model.findSynonyms("a", 2).select("word", fmt("similarity", 5).alias("similarity")).show() +----+----------+ @@ -2927,6 +2929,19 @@ class Word2VecModel(JavaModel, JavaMLReadable, JavaMLWritable): word = _convert_to_vector(word) return self._call_java("findSynonyms", word, num) + @since("2.3.0") + def findSynonymsArray(self, word, num): + """ + Find "num" number of words closest in similarity to "word". + word can be a string or vector representation. + Returns an array with two fields word and similarity (which + gives the cosine similarity). + """ + if not isinstance(word, basestring): + word = _convert_to_vector(word) + tuples = self._java_obj.findSynonymsArray(word, num) + return list(map(lambda st: (st._1(), st._2()), list(tuples))) + @inherit_doc class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): -- GitLab