diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index f1be971a6ae94d8508f0d3d55c0f9c70df030ad4..00abbbe29c0d07431b3bace41e777e5958e4fccb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -170,16 +170,10 @@ class CountVectorizer(override val uid: String)
       (word, count)
     }.cache()
     val fullVocabSize = wordCounts.count()
-    val vocab: Array[String] = {
-      val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) {
-        // Use all terms
-        wordCounts.collect().sortBy(-_._2)
-      } else {
-        // Sort terms to select vocab
-        wordCounts.sortBy(_._2, ascending = false).take(vocSize)
-      }
-      tmpSortedWC.map(_._1)
-    }
+
+    val vocab = wordCounts
+      .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2))
+      .map(_._1)
 
     require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.")
     copyValues(new CountVectorizerModel(uid, vocab).setParent(this))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index ff0de06e27d014d959ec1099037b58d02696dede..7641e3b8cf668b356508a877d9ebab8acfefba3e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -59,14 +59,15 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
       (0, split("a b c d e"),
         Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))),
       (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))),
-      (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))),
-      (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0)))))
+      (2, split("c c"), Vectors.sparse(5, Seq((2, 2.0)))),
+      (3, split("d"), Vectors.sparse(5, Seq((3, 1.0)))),
+      (4, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0)))))
     ).toDF("id", "words", "expected")
     val cv = new CountVectorizer()
       .setInputCol("words")
       .setOutputCol("features")
       .fit(df)
-    assert(cv.vocabulary === Array("a", "b", "c", "d", "e"))
+    assert(cv.vocabulary.toSet === Set("a", "b", "c", "d", "e"))
 
     cv.transform(df).select("features", "expected").collect().foreach {
       case Row(features: Vector, expected: Vector) =>