diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index b4f46cef798dd5e325a98cda1c30da679703e459..29acc3eb5865f052200973b52ed5524b6daac379 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -153,7 +153,7 @@ class Word2VecModel private[ml] ( * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and * and the vector the DenseVector that it is mapped to. */ - val getVectors: DataFrame = { + @transient lazy val getVectors: DataFrame = { val sc = SparkContext.getOrCreate() val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 3f04c41ac5ab62b666e7a296655b1e2c0f867e72..cb4dfa21298ceed377238d11ecc5e0ddbfd6d43a 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -15,11 +15,16 @@ # limitations under the License. # +import sys +if sys.version > '3': + basestring = str + from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import * from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer from pyspark.mllib.common import inherit_doc +from pyspark.mllib.linalg import _convert_to_vector __all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', @@ -954,6 +959,23 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has >>> sent = ("a b " * 100 + "a c " * 10).split(" ") >>> doc = sqlContext.createDataFrame([(sent,), (sent,)], ["sentence"]) >>> model = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model").fit(doc) + >>> model.getVectors().show() + +----+--------------------+ + |word| vector| + +----+--------------------+ + | a|[-0.3511952459812...| + | b|[0.29077222943305...| + | c|[0.02315592765808...| + +----+--------------------+ + ... + >>> model.findSynonyms("a", 2).show() + +----+-------------------+ + |word| similarity| + +----+-------------------+ + | b|0.29255685145799626| + | c|-0.5414068302988307| + +----+-------------------+ + ... >>> model.transform(doc).head().model DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276]) """ @@ -1047,6 +1069,24 @@ class Word2VecModel(JavaModel): Model fitted by Word2Vec. """ + def getVectors(self): + """ + Returns the vector representation of the words as a dataframe + with two fields, word and vector. + """ + return self._call_java("getVectors") + + def findSynonyms(self, word, num): + """ + Find "num" number of words closest in similarity to "word". + word can be a string or vector representation. + Returns a dataframe with two fields word and similarity (which + gives the cosine similarity). + """ + if not isinstance(word, basestring): + word = _convert_to_vector(word) + return self._call_java("findSynonyms", word, num) + @inherit_doc class PCA(JavaEstimator, HasInputCol, HasOutputCol):