diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index ecd49ea2ff533156972ac17d49f9c34e2490bbac..d2ae62b482aff1fc9e4ea5ce8a6ed063f3388ebc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -34,6 +34,7 @@ import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.rdd._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap
 
 /**
  *  Entry in vocabulary 
@@ -287,11 +288,12 @@ class Word2Vec extends Serializable with Logging {
     var syn0Global =
       Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
     var syn1Global = new Array[Float](vocabSize * vectorSize)
-
     var alpha = startingAlpha
     for (k <- 1 to numIterations) {
       val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
         val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
+        val syn0Modify = new Array[Int](vocabSize)
+        val syn1Modify = new Array[Int](vocabSize)
         val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
           case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
             var lwc = lastWordCount
@@ -321,7 +323,8 @@ class Word2Vec extends Serializable with Logging {
                     // Hierarchical softmax
                     var d = 0
                     while (d < bcVocab.value(word).codeLen) {
-                      val l2 = bcVocab.value(word).point(d) * vectorSize
+                      val inner = bcVocab.value(word).point(d)
+                      val l2 = inner * vectorSize
                       // Propagate hidden -> output
                       var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
                       if (f > -MAX_EXP && f < MAX_EXP) {
@@ -330,10 +333,12 @@ class Word2Vec extends Serializable with Logging {
                         val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
                         blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
                         blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
+                        syn1Modify(inner) += 1
                       }
                       d += 1
                     }
                     blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
+                    syn0Modify(lastWord) += 1
                   }
                 }
                 a += 1
@@ -342,21 +347,36 @@ class Word2Vec extends Serializable with Logging {
             }
             (syn0, syn1, lwc, wc)
         }
-        Iterator(model)
+        val syn0Local = model._1
+        val syn1Local = model._2
+        val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2)
+        var index = 0
+        while(index < vocabSize) {
+          if (syn0Modify(index) != 0) {
+            synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))
+          }
+          if (syn1Modify(index) != 0) {
+            synOut.update(index + vocabSize,
+              syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))
+          }
+          index += 1
+        }
+        Iterator(synOut)
       }
-      val (aggSyn0, aggSyn1, _, _) =
-        partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-          val n = syn0_1.length
-          val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
-          val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
-          blas.sscal(n, weight1, syn0_1, 1)
-          blas.sscal(n, weight1, syn1_1, 1)
-          blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-          blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
-          (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+      val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) =>
+          blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
+          v1
+      }.collect()
+      var i = 0
+      while (i < synAgg.length) {
+        val index = synAgg(i)._1
+        if (index < vocabSize) {
+          Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
+        } else {
+          Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
         }
-      syn0Global = aggSyn0
-      syn1Global = aggSyn1
+        i += 1
+      }
     }
     newSentences.unpersist()