From 27524a3a9ccee6fbe56149180ebfb3f74e0957e7 Mon Sep 17 00:00:00 2001
From: Yuming Wang <q79969786@gmail.com>
Date: Wed, 11 Nov 2015 09:43:26 -0800
Subject: [PATCH] [SPARK-11626][ML] ml.feature.Word2Vec.transform() function
 very slow

org.apache.spark.ml.feature.Word2Vec.transform() very slow. we should not read broadcast every sentence.

Author: Yuming Wang <q79969786@gmail.com>
Author: yuming.wang <q79969786@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #9592 from 979969786/master.
---
 .../apache/spark/ml/feature/Word2Vec.scala    | 34 +++++++++----------
 1 file changed, 16 insertions(+), 18 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 9edab3af91..5c64cb09d5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -17,18 +17,16 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
 import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
 import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors}
-import org.apache.spark.mllib.linalg.BLAS._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
+import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.types._
 
 /**
@@ -148,10 +146,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
 @Experimental
 class Word2VecModel private[ml] (
     override val uid: String,
-    wordVectors: feature.Word2VecModel)
+    @transient wordVectors: feature.Word2VecModel)
   extends Model[Word2VecModel] with Word2VecBase {
 
-
   /**
    * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and
    * and the vector the DenseVector that it is mapped to.
@@ -197,22 +194,23 @@ class Word2VecModel private[ml] (
    */
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
+    val vectors = wordVectors.getVectors
+      .mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
+      .map(identity) // mapValues doesn't return a serializable map (SI-7005)
+    val bVectors = dataset.sqlContext.sparkContext.broadcast(vectors)
+    val d = $(vectorSize)
     val word2Vec = udf { sentence: Seq[String] =>
       if (sentence.size == 0) {
-        Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double])
+        Vectors.sparse(d, Array.empty[Int], Array.empty[Double])
       } else {
-        val cum = Vectors.zeros($(vectorSize))
-        val model = bWordVectors.value.getVectors
-        for (word <- sentence) {
-          if (model.contains(word)) {
-            axpy(1.0, bWordVectors.value.transform(word), cum)
-          } else {
-            // pass words which not belong to model
+        val sum = Vectors.zeros(d)
+        sentence.foreach { word =>
+          bVectors.value.get(word).foreach { v =>
+            BLAS.axpy(1.0, v, sum)
           }
         }
-        scal(1.0 / sentence.size, cum)
-        cum
+        BLAS.scal(1.0 / sentence.size, sum)
+        sum
       }
     }
     dataset.withColumn($(outputCol), word2Vec(col($(inputCol))))
-- 
GitLab