diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index f1be971a6ae94d8508f0d3d55c0f9c70df030ad4..00abbbe29c0d07431b3bace41e777e5958e4fccb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -170,16 +170,10 @@ class CountVectorizer(override val uid: String) (word, count) }.cache() val fullVocabSize = wordCounts.count() - val vocab: Array[String] = { - val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) { - // Use all terms - wordCounts.collect().sortBy(-_._2) - } else { - // Sort terms to select vocab - wordCounts.sortBy(_._2, ascending = false).take(vocSize) - } - tmpSortedWC.map(_._1) - } + + val vocab = wordCounts + .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2)) + .map(_._1) require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.") copyValues(new CountVectorizerModel(uid, vocab).setParent(this)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index ff0de06e27d014d959ec1099037b58d02696dede..7641e3b8cf668b356508a877d9ebab8acfefba3e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -59,14 +59,15 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext (0, split("a b c d e"), Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))), (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))), - (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))), - (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) + (2, split("c c"), Vectors.sparse(5, Seq((2, 2.0)))), + (3, split("d"), Vectors.sparse(5, Seq((3, 1.0)))), + (4, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) ).toDF("id", "words", "expected") val cv = new CountVectorizer() .setInputCol("words") .setOutputCol("features") .fit(df) - assert(cv.vocabulary === Array("a", "b", "c", "d", "e")) + assert(cv.vocabulary.toSet === Set("a", "b", "c", "d", "e")) cv.transform(df).select("features", "expected").collect().foreach { case Row(features: Vector, expected: Vector) =>