diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index dee898827f30fbe1bebd3fc19949ccfe6b1f623f..3241ebeb22c4256f99ce29bac4428a1f74e26ab7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -76,6 +76,23 @@ class Word2Vec extends Serializable with Logging {
   private var numIterations = 1
   private var seed = Utils.random.nextLong()
   private var minCount = 5
+  private var maxSentenceLength = 1000
+
+  /**
+   * Sets the maximum length (in words) of each sentence in the input data.
+   * Any sentence longer than this threshold will be divided into chunks of
+   * up to `maxSentenceLength` words (default: 1000).
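+   *
+   * For example, a minimal usage sketch (the value 500 is arbitrary):
+   * {{{
+   *   val word2vec = new Word2Vec().setMaxSentenceLength(500)
+   * }}}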
+   */
+  @Since("2.0.0")
+  def setMaxSentenceLength(maxSentenceLength: Int): this.type = {
+    this.maxSentenceLength = maxSentenceLength
+    this
+  }
 
   /**
    * Sets vector size (default: 100).
@@ -146,7 +163,6 @@ class Word2Vec extends Serializable with Logging {
   private val EXP_TABLE_SIZE = 1000
   private val MAX_EXP = 6
   private val MAX_CODE_LENGTH = 40
-  private val MAX_SENTENCE_LENGTH = 1000
 
   /** context words from [-window, window] */
   private var window = 5
@@ -156,7 +172,10 @@ class Word2Vec extends Serializable with Logging {
   @transient private var vocab: Array[VocabWord] = null
   @transient private var vocabHash = mutable.HashMap.empty[String, Int]
 
-  private def learnVocab(words: RDD[String]): Unit = {
+  private def learnVocab[S <: Iterable[String]](dataset: RDD[S]): Unit = {
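+    // Flatten the sentences into a single stream of words for vocabulary counting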
+    val words = dataset.flatMap(x => x)
+
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
@@ -272,15 +291,15 @@
 
   /**
    * Computes the vector representation of each word in vocabulary.
-   * @param dataset an RDD of words
+   * @param dataset an RDD of sentences, where each sentence is expressed
+   *                as an iterable collection of words
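+   *                (so an `RDD[Seq[String]]`, for instance, is accepted)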
    * @return a Word2VecModel
    */
   @Since("1.1.0")
   def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
 
-    val words = dataset.flatMap(x => x)
-
-    learnVocab(words)
+    learnVocab(dataset)
 
     createBinaryTree()
 
@@ -289,25 +308,16 @@
     val expTable = sc.broadcast(createExpTable())
     val bcVocab = sc.broadcast(vocab)
     val bcVocabHash = sc.broadcast(vocabHash)
-
-    val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
-      new Iterator[Array[Int]] {
-        def hasNext: Boolean = iter.hasNext
-
-        def next(): Array[Int] = {
-          val sentence = ArrayBuilder.make[Int]
-          var sentenceLength = 0
-          while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
-            val word = bcVocabHash.value.get(iter.next())
-            word match {
-              case Some(w) =>
-                sentence += w
-                sentenceLength += 1
-              case None =>
-            }
-          }
-          sentence.result()
-        }
+    // Each partition is a collection of sentences,
+    // which will be translated into arrays of word indices.
+    val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
+      // Each sentence maps to zero or more Array[Int] chunks
+      sentenceIter.flatMap { sentence =>
+        // Map each word to its vocabulary index, dropping out-of-vocabulary words
+        val wordIndexes = sentence.flatMap(bcVocabHash.value.get)
+        // Break wordIndexes into chunks of at most maxSentenceLength words
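+        // e.g. 2500 indices with maxSentenceLength = 1000 become chunks of 1000, 1000 and 500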
+        wordIndexes.grouped(maxSentenceLength).map(_.toArray)
       }
     }
 
@@ -477,15 +487,6 @@ class Word2VecModel private[spark] (
     this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model))
   }
 
-  private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
-    require(v1.length == v2.length, "Vectors should have the same length")
-    val n = v1.length
-    val norm1 = blas.snrm2(n, v1, 1)
-    val norm2 = blas.snrm2(n, v2, 1)
-    if (norm1 == 0 || norm2 == 0) return 0.0
-    blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2
-  }
-
   override protected def formatVersion = "1.0"
 
   @Since("1.4.0")
@@ -542,6 +543,8 @@ class Word2VecModel private[spark] (
     // Need not divide with the norm of the given vector since it is constant.
     val cosVec = cosineVec.map(_.toDouble)
     var ind = 0
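+    // Norm of the query vector; used below to scale results into true cosine similarities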
+    val vecNorm = blas.snrm2(vectorSize, fVector, 1)
     while (ind < numWords) {
       val norm = wordVecNorms(ind)
       if (norm == 0.0) {
@@ -551,12 +554,19 @@ class Word2VecModel private[spark] (
       }
       ind += 1
     }
-    wordList.zip(cosVec)
+    var topResults = wordList.zip(cosVec)
       .toSeq
-      .sortBy(- _._2)
+      .sortBy(-_._2)
       .take(num + 1)
       .tail
-      .toArray
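+    // cosVec was divided only by each word's norm, so divide by the
+    // query vector's norm as well to return true cosine similarities.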
+    if (vecNorm != 0.0f) {
+      topResults = topResults.map { case (word, cosVal) =>
+        (word, cosVal / vecNorm)
+      }
+    }
+    topResults.toArray
   }
 
   /**
@@ -568,6 +578,7 @@ class Word2VecModel private[spark] (
       (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))
     }
   }
+
 }
 
 @Since("1.4.0")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index a73b565125668775c0585d144fed12dd48e1b9e7..f094c550e545a7a1e39db4e4c736e446b66d7a44 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -133,7 +133,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
       .setSeed(42L)
       .fit(docDF)
 
-    val expectedSimilarity = Array(0.18032623242822343, -0.5717976464798823)
+    val expectedSimilarity = Array(0.2608488929093532, -0.8271274846926078)
     val (synonyms, similarity) = model.findSynonyms("a", 2).map {
       case Row(w: String, sim: Double) => (w, sim)
     }.collect().unzip
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index d017a231886cbb6635daf9edefa0a4399c85467d..464c9446f2f3932b02180ecbd11dcf46d1bba76f 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -1836,12 +1836,12 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
     +----+--------------------+
     ...
     >>> model.findSynonyms("a", 2).show()
-    +----+--------------------+
-    |word|          similarity|
-    +----+--------------------+
-    |   b| 0.16782984556103436|
-    |   c|-0.46761559092107646|
-    +----+--------------------+
+    +----+-------------------+
+    |word|         similarity|
+    +----+-------------------+
+    |   b| 0.2505344027513247|
+    |   c|-0.6980510075367647|
+    +----+-------------------+
     ...
     >>> model.transform(doc).head().model
     DenseVector([0.5524, -0.4995, -0.3599, 0.0241, 0.3461])