diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 87c81e7b0bd2f0839f9e8c27191808c1e9bad946..3bf44ad7c44e36af30c9330258851962ccffb4a8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -19,16 +19,17 @@ package org.apache.spark.mllib.feature
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
-import org.apache.spark.{HashPartitioner, Logging}
+
+import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.rdd._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
 
 /**
  *  Entry in vocabulary 
@@ -58,29 +59,63 @@ private case class VocabWord(
  * Efficient Estimation of Word Representations in Vector Space
  * and 
  * Distributed Representations of Words and Phrases and their Compositionality.
- * @param size vector dimension
- * @param startingAlpha initial learning rate
- * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
- * @param numIterations number of iterations to run, should be smaller than or equal to parallelism
  */
 @Experimental
-class Word2Vec(
-    val size: Int,
-    val startingAlpha: Double,
-    val parallelism: Int,
-    val numIterations: Int) extends Serializable with Logging {
+class Word2Vec extends Serializable with Logging {
+
+  private var vectorSize = 100
+  private var startingAlpha = 0.025
+  private var numPartitions = 1
+  private var numIterations = 1
+  private var seed = Utils.random.nextLong()
+
+  /**
+   * Sets vector size (default: 100).
+   */
+  def setVectorSize(vectorSize: Int): this.type = {
+    this.vectorSize = vectorSize
+    this
+  }
+
+  /**
+   * Sets initial learning rate (default: 0.025).
+   */
+  def setLearningRate(learningRate: Double): this.type = {
+    this.startingAlpha = learningRate
+    this
+  }
 
   /**
-   * Word2Vec with a single thread.
+   * Sets number of partitions (default: 1). Use a small number for accuracy.
    */
-  def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1)
+  def setNumPartitions(numPartitions: Int): this.type = {
+    require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions")
+    this.numPartitions = numPartitions
+    this
+  }
+
+  /**
+   * Sets number of iterations (default: 1), which should be smaller than or equal to number of
+   * partitions.
+   */
+  def setNumIterations(numIterations: Int): this.type = {
+    this.numIterations = numIterations
+    this
+  }
+
+  /**
+   * Sets random seed (default: a random long integer).
+   */
+  def setSeed(seed: Long): this.type = {
+    this.seed = seed
+    this
+  }
 
   private val EXP_TABLE_SIZE = 1000
   private val MAX_EXP = 6
   private val MAX_CODE_LENGTH = 40
   private val MAX_SENTENCE_LENGTH = 1000
-  private val layer1Size = size 
-  private val modelPartitionNum = 100
+  private val layer1Size = vectorSize
 
   /** context words from [-window, window] */
   private val window = 5
@@ -94,12 +129,12 @@ class Word2Vec(
   private var vocabHash = mutable.HashMap.empty[String, Int]
   private var alpha = startingAlpha
 
-  private def learnVocab(words:RDD[String]): Unit = {
+  private def learnVocab(words: RDD[String]): Unit = {
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
       .map(x => VocabWord(
-        x._1, 
-        x._2, 
+        x._1,
+        x._2,
         new Array[Int](MAX_CODE_LENGTH), 
         new Array[Int](MAX_CODE_LENGTH), 
         0))
@@ -245,23 +280,24 @@ class Word2Vec(
       }
     }
     
-    val newSentences = sentences.repartition(parallelism).cache()
+    val newSentences = sentences.repartition(numPartitions).cache()
+    val initRandom = new XORShiftRandom(seed)
     var syn0Global =
-      Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
+      Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size)
     var syn1Global = new Array[Float](vocabSize * layer1Size)
-    
-    for(iter <- 1 to numIterations) {
-      val (aggSyn0, aggSyn1, _, _) =
-        // TODO: broadcast temp instead of serializing it directly
-        // or initialize the model in each executor
-        newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
-          seqOp = (c, v) => (c, v) match { 
+
+    for (k <- 1 to numIterations) {
+      val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
+        val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
+        val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
           case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
             var lwc = lastWordCount
-            var wc = wordCount 
+            var wc = wordCount
             if (wordCount - lastWordCount > 10000) {
               lwc = wordCount
-              alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
+              // TODO: discount by iteration?
+              alpha =
+                startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
               if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
               logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
             }
@@ -269,8 +305,7 @@ class Word2Vec(
             var pos = 0
             while (pos < sentence.size) {
               val word = sentence(pos)
-              // TODO: fix random seed
-              val b = Random.nextInt(window)
+              val b = random.nextInt(window)
               // Train Skip-gram
               var a = b
               while (a < window * 2 + 1 - b) {
@@ -280,7 +315,7 @@ class Word2Vec(
                     val lastWord = sentence(c)
                     val l1 = lastWord * layer1Size
                     val neu1e = new Array[Float](layer1Size)
-                    // Hierarchical softmax 
+                    // Hierarchical softmax
                     var d = 0
                     while (d < bcVocab.value(word).codeLen) {
                       val l2 = bcVocab.value(word).point(d) * layer1Size
@@ -303,44 +338,44 @@ class Word2Vec(
               pos += 1
             }
             (syn0, syn1, lwc, wc)
-          },
-          combOp = (c1, c2) => (c1, c2) match { 
-            case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-              val n = syn0_1.length
-              val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
-              val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
-              blas.sscal(n, weight1, syn0_1, 1)
-              blas.sscal(n, weight1, syn1_1, 1)
-              blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-              blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
-              (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
-          })
+        }
+        Iterator(model)
+      }
+      val (aggSyn0, aggSyn1, _, _) =
+        partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
+          val n = syn0_1.length
+          val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
+          val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
+          blas.sscal(n, weight1, syn0_1, 1)
+          blas.sscal(n, weight1, syn1_1, 1)
+          blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
+          blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
+          (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+        }
       syn0Global = aggSyn0
       syn1Global = aggSyn1
     }
     newSentences.unpersist()
     
-    val wordMap = new Array[(String, Array[Float])](vocabSize)
+    val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
     var i = 0
     while (i < vocabSize) {
       val word = bcVocab.value(i).word
       val vector = new Array[Float](layer1Size)
       Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
-      wordMap(i) = (word, vector)
+      word2VecMap += word -> vector
       i += 1
     }
-    val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
-      .partitionBy(new HashPartitioner(modelPartitionNum))
-      .persist(StorageLevel.MEMORY_AND_DISK)
-    
-    new Word2VecModel(modelRDD)
+
+    new Word2VecModel(word2VecMap.toMap)
   }
 }
 
 /**
 * Word2Vec model
-*/
-class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
+ */
+class Word2VecModel private[mllib] (
+    private val model: Map[String, Array[Float]]) extends Serializable {
 
   private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
@@ -357,11 +392,12 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
    * @return vector representation of word
    */
   def transform(word: String): Vector = {
-    val result = model.lookup(word) 
-    if (result.isEmpty) {
-      throw new IllegalStateException(s"$word not in vocabulary")
+    model.get(word) match {
+      case Some(vec) =>
+        Vectors.dense(vec.map(_.toDouble))
+      case None =>
+        throw new IllegalStateException(s"$word not in vocabulary")
     }
-    else Vectors.dense(result(0).map(_.toDouble))
   }
   
   /**
@@ -392,33 +428,13 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
    */
   def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
-    val topK = model.map { case(w, vec) => 
-      (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
-    .sortByKey(ascending = false)
-    .take(num + 1)
-    .map(_.swap)
-    .tail
-    
-    topK
-  }
-}
-
-object Word2Vec{
-  /**
-   * Train Word2Vec model
-   * @param input RDD of words
-   * @param size vector dimension
-   * @param startingAlpha initial learning rate
-   * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
-   * @param numIterations number of iterations, should be smaller than or equal to parallelism
-   * @return Word2Vec model
-   */
-  def train[S <: Iterable[String]](
-    input: RDD[S],
-    size: Int,
-    startingAlpha: Double,
-    parallelism: Int = 1,
-    numIterations:Int = 1): Word2VecModel = {
-    new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input)
+    // TODO: optimize top-k
+    val fVector = vector.toArray.map(_.toFloat)
+    model.mapValues(vec => cosineSimilarity(fVector, vec))
+      .toSeq
+      .sortBy(- _._2)
+      .take(num + 1)
+      .tail
+      .toArray
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index b5db39b68a223b171ee86cf58bd65c4f60d3dfec..e34335d89eb755476239d8cb790aded6e3446036 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -30,29 +30,22 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
     val localDoc = Seq(sentence, sentence)
     val doc = sc.parallelize(localDoc)
       .map(line => line.split(" ").toSeq)
-    val size = 10
-    val startingAlpha = 0.025
-    val window = 2 
-    val minCount = 2
-    val num = 2
-
-    val model = Word2Vec.train(doc, size, startingAlpha)
+    val model = new Word2Vec().setVectorSize(10).setSeed(42L).fit(doc)
     val syms = model.findSynonyms("a", 2)
-    assert(syms.length == num)
+    assert(syms.length == 2)
     assert(syms(0)._1 == "b")
     assert(syms(1)._1 == "c")
   }
 
-
   test("Word2VecModel") {
     val num = 2
-    val localModel = Seq(
+    val word2VecMap = Map(
       ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
       ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
       ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
       ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
     )
-    val model = new Word2VecModel(sc.parallelize(localModel, 2))
+    val model = new Word2VecModel(word2VecMap)
     val syms = model.findSynonyms("china", num)
     assert(syms.length == num)
     assert(syms(0)._1 == "taiwan")