Skip to content
Snippets Groups Projects
Commit f4344582 authored by fwang1's avatar fwang1 Committed by Xiangrui Meng
Browse files

[SPARK-14497][ML] Use top instead of sortBy() to get top N frequent words as...

[SPARK-14497][ML] Use top instead of sortBy() to get top N frequent words as dict in ConutVectorizer

## What changes were proposed in this pull request?

Replace sortBy() with top() to calculate the top N frequent words as dictionary.

## How was this patch tested?
existing unit tests.  The terms with same TF would be sorted in descending order. The test would fail if hardcode the terms with same TF the dictionary like "c", "d"...

Author: fwang1 <desperado.wf@gmail.com>

Closes #12265 from lionelfeng/master.
parent 22014e6f
No related branches found
No related tags found
No related merge requests found
...@@ -170,16 +170,10 @@ class CountVectorizer(override val uid: String) ...@@ -170,16 +170,10 @@ class CountVectorizer(override val uid: String)
(word, count) (word, count)
}.cache() }.cache()
val fullVocabSize = wordCounts.count() val fullVocabSize = wordCounts.count()
val vocab: Array[String] = {
val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) { val vocab = wordCounts
// Use all terms .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2))
wordCounts.collect().sortBy(-_._2) .map(_._1)
} else {
// Sort terms to select vocab
wordCounts.sortBy(_._2, ascending = false).take(vocSize)
}
tmpSortedWC.map(_._1)
}
require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.") require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.")
copyValues(new CountVectorizerModel(uid, vocab).setParent(this)) copyValues(new CountVectorizerModel(uid, vocab).setParent(this))
......
...@@ -59,14 +59,15 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext ...@@ -59,14 +59,15 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
(0, split("a b c d e"), (0, split("a b c d e"),
Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))), Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))),
(1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))), (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))),
(2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))), (2, split("c c"), Vectors.sparse(5, Seq((2, 2.0)))),
(3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) (3, split("d"), Vectors.sparse(5, Seq((3, 1.0)))),
(4, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0)))))
).toDF("id", "words", "expected") ).toDF("id", "words", "expected")
val cv = new CountVectorizer() val cv = new CountVectorizer()
.setInputCol("words") .setInputCol("words")
.setOutputCol("features") .setOutputCol("features")
.fit(df) .fit(df)
assert(cv.vocabulary === Array("a", "b", "c", "d", "e")) assert(cv.vocabulary.toSet === Set("a", "b", "c", "d", "e"))
cv.transform(df).select("features", "expected").collect().foreach { cv.transform(df).select("features", "expected").collect().foreach {
case Row(features: Vector, expected: Vector) => case Row(features: Vector, expected: Vector) =>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment