Skip to content
Snippets Groups Projects
Commit 076ec056 authored by MechCoder's avatar MechCoder Committed by Joseph K. Bradley
Browse files

[SPARK-9533] [PYSPARK] [ML] Add missing methods in Word2Vec ML

After https://github.com/apache/spark/pull/7263 it is pretty straightforward to Python wrappers.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #7930 from MechCoder/spark-9533 and squashes the following commits:

1bea394 [MechCoder] make getVectors a lazy val
5522756 [MechCoder] [SPARK-9533] [PySpark] [ML] Add missing methods in Word2Vec ML
parent c5c6aded
No related branches found
No related tags found
No related merge requests found
......@@ -153,7 +153,7 @@ class Word2VecModel private[ml] (
* Returns a dataframe with two fields, "word" and "vector", with "word" being a String and
* and the vector the DenseVector that it is mapped to.
*/
val getVectors: DataFrame = {
@transient lazy val getVectors: DataFrame = {
val sc = SparkContext.getOrCreate()
val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
......
......@@ -15,11 +15,16 @@
# limitations under the License.
#
import sys
if sys.version > '3':
basestring = str
from pyspark.rdd import ignore_unicode_prefix
from pyspark.ml.param.shared import *
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer
from pyspark.mllib.common import inherit_doc
from pyspark.mllib.linalg import _convert_to_vector
__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder',
'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel',
......@@ -954,6 +959,23 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
>>> sent = ("a b " * 100 + "a c " * 10).split(" ")
>>> doc = sqlContext.createDataFrame([(sent,), (sent,)], ["sentence"])
>>> model = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model").fit(doc)
>>> model.getVectors().show()
+----+--------------------+
|word| vector|
+----+--------------------+
| a|[-0.3511952459812...|
| b|[0.29077222943305...|
| c|[0.02315592765808...|
+----+--------------------+
...
>>> model.findSynonyms("a", 2).show()
+----+-------------------+
|word| similarity|
+----+-------------------+
| b|0.29255685145799626|
| c|-0.5414068302988307|
+----+-------------------+
...
>>> model.transform(doc).head().model
DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276])
"""
......@@ -1047,6 +1069,24 @@ class Word2VecModel(JavaModel):
Model fitted by Word2Vec.
"""
def getVectors(self):
"""
Returns the vector representation of the words as a dataframe
with two fields, word and vector.
"""
return self._call_java("getVectors")
def findSynonyms(self, word, num):
"""
Find "num" number of words closest in similarity to "word".
word can be a string or vector representation.
Returns a dataframe with two fields word and similarity (which
gives the cosine similarity).
"""
if not isinstance(word, basestring):
word = _convert_to_vector(word)
return self._call_java("findSynonyms", word, num)
@inherit_doc
class PCA(JavaEstimator, HasInputCol, HasOutputCol):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment