Skip to content
Snippets Groups Projects
Commit 1870dbaa authored by Xiangrui Meng's avatar Xiangrui Meng
Browse files

[MLLIB] minor update to word2vec

very minor update Ishiihara

Author: Xiangrui Meng <meng@databricks.com>

Closes #2043 from mengxr/minor-w2v and squashes the following commits:

be649fd [Xiangrui Meng] remove map because we only need append
eccefcc [Xiangrui Meng] minor updates to word2vec
parent 8b9dc991
No related branches found
No related tags found
No related merge requests found
......@@ -30,11 +30,9 @@ import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap
/**
* Entry in vocabulary
......@@ -285,9 +283,9 @@ class Word2Vec extends Serializable with Logging {
val newSentences = sentences.repartition(numPartitions).cache()
val initRandom = new XORShiftRandom(seed)
var syn0Global =
val syn0Global =
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
var syn1Global = new Array[Float](vocabSize * vectorSize)
val syn1Global = new Array[Float](vocabSize * vectorSize)
var alpha = startingAlpha
for (k <- 1 to numIterations) {
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
......@@ -349,21 +347,21 @@ class Word2Vec extends Serializable with Logging {
}
val syn0Local = model._1
val syn1Local = model._2
val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2)
val synOut = mutable.ListBuffer.empty[(Int, Array[Float])]
var index = 0
while(index < vocabSize) {
if (syn0Modify(index) != 0) {
synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))
synOut += ((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
}
if (syn1Modify(index) != 0) {
synOut.update(index + vocabSize,
syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))
synOut += ((index + vocabSize,
syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
}
index += 1
}
Iterator(synOut)
synOut.toIterator
}
val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) =>
val synAgg = partial.reduceByKey { case (v1, v2) =>
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
v1
}.collect()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment