From f4344582ba28983bf3892d08e11236f090f5bf92 Mon Sep 17 00:00:00 2001 From: fwang1 <desperado.wf@gmail.com> Date: Sun, 10 Apr 2016 01:13:25 -0700 Subject: [PATCH] [SPARK-14497][ML] Use top instead of sortBy() to get top N frequent words as dict in ConutVectorizer ## What changes were proposed in this pull request? Replace sortBy() with top() to calculate the top N frequent words as dictionary. ## How was this patch tested? existing unit tests. The terms with same TF would be sorted in descending order. The test would fail if hardcode the terms with same TF the dictionary like "c", "d"... Author: fwang1 <desperado.wf@gmail.com> Closes #12265 from lionelfeng/master. --- .../apache/spark/ml/feature/CountVectorizer.scala | 14 ++++---------- .../spark/ml/feature/CountVectorizerSuite.scala | 7 ++++--- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index f1be971a6a..00abbbe29c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -170,16 +170,10 @@ class CountVectorizer(override val uid: String) (word, count) }.cache() val fullVocabSize = wordCounts.count() - val vocab: Array[String] = { - val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) { - // Use all terms - wordCounts.collect().sortBy(-_._2) - } else { - // Sort terms to select vocab - wordCounts.sortBy(_._2, ascending = false).take(vocSize) - } - tmpSortedWC.map(_._1) - } + + val vocab = wordCounts + .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2)) + .map(_._1) require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.") copyValues(new CountVectorizerModel(uid, vocab).setParent(this)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index ff0de06e27..7641e3b8cf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -59,14 +59,15 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext (0, split("a b c d e"), Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))), (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))), - (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))), - (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) + (2, split("c c"), Vectors.sparse(5, Seq((2, 2.0)))), + (3, split("d"), Vectors.sparse(5, Seq((3, 1.0)))), + (4, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) ).toDF("id", "words", "expected") val cv = new CountVectorizer() .setInputCol("words") .setOutputCol("features") .fit(df) - assert(cv.vocabulary === Array("a", "b", "c", "d", "e")) + assert(cv.vocabulary.toSet === Set("a", "b", "c", "d", "e")) cv.transform(df).select("features", "expected").collect().foreach { case Row(features: Vector, expected: Vector) => -- GitLab