Commit 27524a3a authored by Yuming Wang, committed by Xiangrui Meng

[SPARK-11626][ML] ml.feature.Word2Vec.transform() function very slow

org.apache.spark.ml.feature.Word2Vec.transform() is very slow: it reads the broadcast model and rebuilds the word-vector map for every sentence. Instead, the word vectors should be materialized once on the driver, broadcast as a plain map, and only looked up per word inside the UDF (see the sketch below the commit metadata).

Author: Yuming Wang <q79969786@gmail.com>
Author: yuming.wang <q79969786@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #9592 from 979969786/master.
parent 1510c527
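For context, here is a minimal sketch of the pattern this patch applies. The names `sc`, `mllibModel`, and `sentenceVector` are illustrative assumptions, not part of the patch, and the sketch avoids mllib's package-private BLAS helpers (which the real patch uses inside Spark's own source tree). The idea: build the word-to-vector map once on the driver, broadcast that plain map, and do only cheap hash-map lookups per word on the executors, instead of rebuilding the vector map from the broadcast model for every sentence.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.feature.Word2VecModel
import org.apache.spark.mllib.linalg.{Vector, Vectors}

// Assumed to exist: a SparkContext and a trained mllib Word2VecModel.
def sentenceVector(sc: SparkContext, mllibModel: Word2VecModel): Seq[String] => Vector = {
  // Materialize word -> vector once, on the driver.
  val vectors: Map[String, Vector] = mllibModel.getVectors
    .mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
    .map(identity) // SI-7005: the mapValues wrapper alone is not serializable
  val d = vectors.headOption.map(_._2.size).getOrElse(0)
  // Broadcast the plain map once; executors reuse it for every row.
  val bVectors = sc.broadcast(vectors)

  // Per-sentence work is now just map lookups plus a running sum.
  (sentence: Seq[String]) => {
    val sum = new Array[Double](d)
    sentence.foreach { word =>
      bVectors.value.get(word).foreach { v =>
        var i = 0
        while (i < d) { sum(i) += v(i); i += 1 }
      }
    }
    if (sentence.nonEmpty) {
      var i = 0
      while (i < d) { sum(i) /= sentence.size; i += 1 }
    }
    Vectors.dense(sum)
  }
}

The authoritative change is the diff below; this sketch only illustrates the broadcast-once, look-up-per-word pattern.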
@@ -17,18 +17,16 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
 import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
 import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors}
-import org.apache.spark.mllib.linalg.BLAS._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
+import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.types._
 
 /**
@@ -148,10 +146,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
 @Experimental
 class Word2VecModel private[ml] (
     override val uid: String,
-    wordVectors: feature.Word2VecModel)
+    @transient wordVectors: feature.Word2VecModel)
   extends Model[Word2VecModel] with Word2VecBase {
 
-
   /**
    * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and
    * and the vector the DenseVector that it is mapped to.
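The `@transient` added above reflects that, after this change, the wrapped mllib model is only needed on the driver (to build the broadcast map), so it need not travel with serialized instances or closures. A tiny, purely illustrative example of the mechanism (the `Wrapper` and `heavyState` names are hypothetical, not from the patch):

// Hypothetical illustration: a @transient field is skipped by Java serialization,
// so shipping a Wrapper instance to executors does not drag heavyState along
// (after deserialization the field is simply null).
class Wrapper(@transient val heavyState: Map[String, Array[Float]]) extends Serializable {
  def vocabSize: Int = heavyState.size // only valid on the side that constructed the instance
}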
@@ -197,22 +194,23 @@ class Word2VecModel private[ml] (
    */
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
+    val vectors = wordVectors.getVectors
+      .mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
+      .map(identity) // mapValues doesn't return a serializable map (SI-7005)
+    val bVectors = dataset.sqlContext.sparkContext.broadcast(vectors)
+    val d = $(vectorSize)
     val word2Vec = udf { sentence: Seq[String] =>
       if (sentence.size == 0) {
-        Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double])
+        Vectors.sparse(d, Array.empty[Int], Array.empty[Double])
       } else {
-        val cum = Vectors.zeros($(vectorSize))
-        val model = bWordVectors.value.getVectors
-        for (word <- sentence) {
-          if (model.contains(word)) {
-            axpy(1.0, bWordVectors.value.transform(word), cum)
-          } else {
-            // pass words which not belong to model
+        val sum = Vectors.zeros(d)
+        sentence.foreach { word =>
+          bVectors.value.get(word).foreach { v =>
+            BLAS.axpy(1.0, v, sum)
           }
         }
-        scal(1.0 / sentence.size, cum)
-        cum
+        BLAS.scal(1.0 / sentence.size, sum)
+        sum
       }
     }
     dataset.withColumn($(outputCol), word2Vec(col($(inputCol))))
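One detail in the new code above deserves a note: in the Scala versions Spark targeted at the time (2.10/2.11), `Map.mapValues` returns a lazy, non-serializable wrapper (SI-7005), so the patch appends `.map(identity)` to force a concrete `Map` before broadcasting it. A plain-Scala illustration, independent of Spark:

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

val m = Map("a" -> 1, "b" -> 2)
val view = m.mapValues(_ * 2)                   // lazy wrapper around m; not Serializable (SI-7005)
val concrete = m.mapValues(_ * 2).map(identity) // forces a real, serializable Map

val out = new ObjectOutputStream(new ByteArrayOutputStream())
out.writeObject(concrete)                       // serializes fine
// out.writeObject(view)                        // would throw NotSerializableException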