diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index 41a27f6208d1b29f32f26ebfd49012a1229a04a5..1511ae6dda4ed68718066d1de3d280aa3146b6cf 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -82,6 +82,21 @@
 tf.cache()
 val idf = new IDF().fit(tf)
 val tfidf: RDD[Vector] = idf.transform(tf)
 {% endhighlight %}
+
+MLlib's IDF implementation provides an option for ignoring terms which occur in fewer than a
+minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature
+can be used by passing the `minDocFreq` value to the IDF constructor.
+
+{% highlight scala %}
+import org.apache.spark.mllib.feature.IDF
+
+// ... continue from the previous example
+tf.cache()
+val idf = new IDF(minDocFreq = 2).fit(tf)
+val tfidf: RDD[Vector] = idf.transform(tf)
+{% endhighlight %}
+
+
 </div>
 </div>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
index d40d5553c1d219b1bff0981c6f396a6c26642dc5..720bb70b08dbf7a5e3d40de34326d20db5021e01 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
@@ -30,9 +30,18 @@ import org.apache.spark.rdd.RDD
  * Inverse document frequency (IDF).
  * The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total
  * number of documents and `d(t)` is the number of documents that contain term `t`.
+ *
+ * This implementation supports filtering out terms which do not appear in a minimum number
+ * of documents (controlled by the variable `minDocFreq`). For terms that appear in fewer
+ * than `minDocFreq` documents, the IDF is set to 0, resulting in TF-IDFs of 0.
+ *
+ * @param minDocFreq minimum number of documents in which a term
+ *                   should appear for filtering
  */
 @Experimental
-class IDF {
+class IDF(val minDocFreq: Int) {
+
+  def this() = this(0)
 
   // TODO: Allow different IDF formulations.
 
@@ -41,7 +50,8 @@ class IDF {
    * @param dataset an RDD of term frequency vectors
    */
   def fit(dataset: RDD[Vector]): IDFModel = {
-    val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)(
+    val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator(
+      minDocFreq = minDocFreq))(
       seqOp = (df, v) => df.add(v),
       combOp = (df1, df2) => df1.merge(df2)
     ).idf()
@@ -60,13 +70,16 @@
 
 private object IDF {
 
   /** Document frequency aggregator. */
-  class DocumentFrequencyAggregator extends Serializable {
+  class DocumentFrequencyAggregator(val minDocFreq: Int) extends Serializable {
 
     /** number of documents */
     private var m = 0L
     /** document frequency vector */
     private var df: BDV[Long] = _
+
+    def this() = this(0)
+
     /** Adds a new document. */
     def add(doc: Vector): this.type = {
       if (isEmpty) {
@@ -123,7 +136,18 @@
       val inv = new Array[Double](n)
       var j = 0
       while (j < n) {
-        inv(j) = math.log((m + 1.0)/ (df(j) + 1.0))
+        /*
+         * If the term is not present in the minimum
+         * number of documents, set the IDF to 0. This
+         * will cause multiplication in IDFModel to
+         * set the TF-IDF to 0.
+         *
+         * Since arrays are initialized to 0 by default,
+         * we just omit changing those entries.
+         */
+        if (df(j) >= minDocFreq) {
+          inv(j) = math.log((m + 1.0) / (df(j) + 1.0))
+        }
         j += 1
       }
       Vectors.dense(inv)
@@ -140,6 +164,11 @@ class IDFModel private[mllib] (val idf: Vector) extends Serializable {
 
   /**
    * Transforms term frequency (TF) vectors to TF-IDF vectors.
+   *
+   * If `minDocFreq` was set for the IDF calculation,
+   * the terms which occur in fewer than `minDocFreq`
+   * documents will have an entry of 0.
+   *
    * @param dataset an RDD of term frequency vectors
    * @return an RDD of TF-IDF vectors
    */
diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
index e8d99f4ae43aecb7b04cd5b09d5b9f5854ec41ec..064263e02cd11cf2b418a494585af7428c426a63 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
@@ -63,4 +63,24 @@
       Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
     }
   }
+
+  @Test
+  public void tfIdfMinimumDocumentFrequency() {
+    // This test is to check Java compatibility.
+    HashingTF tf = new HashingTF();
+    JavaRDD<ArrayList<String>> documents = sc.parallelize(Lists.newArrayList(
+      Lists.newArrayList("this is a sentence".split(" ")),
+      Lists.newArrayList("this is another sentence".split(" ")),
+      Lists.newArrayList("this is still a sentence".split(" "))), 2);
+    JavaRDD<Vector> termFreqs = tf.transform(documents);
+    termFreqs.collect();
+    IDF idf = new IDF(2);
+    JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
+    List<Vector> localTfIdfs = tfIdfs.collect();
+    int indexOfThis = tf.indexOf("this");
+    for (Vector v: localTfIdfs) {
+      Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
+    }
+  }
+
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
index 53d9c0c640b980d8523967940a99268432741a04..43974f84e3ca8fb7376a7c730c263848107ba2f2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
@@ -38,7 +38,7 @@
     val idf = new IDF
     val model = idf.fit(termFrequencies)
     val expected = Vectors.dense(Array(0, 3, 1, 2).map { x =>
-      math.log((m.toDouble + 1.0) / (x + 1.0))
+      math.log((m + 1.0) / (x + 1.0))
     })
     assert(model.idf ~== expected absTol 1e-12)
     val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap()
@@ -54,4 +54,38 @@
     assert(tfidf2.indices === Array(1))
     assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12)
   }
+
+  test("idf minimum document frequency filtering") {
+    val n = 4
+    val localTermFrequencies = Seq(
+      Vectors.sparse(n, Array(1, 3), Array(1.0, 2.0)),
+      Vectors.dense(0.0, 1.0, 2.0, 3.0),
+      Vectors.sparse(n, Array(1), Array(1.0))
+    )
+    val m = localTermFrequencies.size
+    val termFrequencies = sc.parallelize(localTermFrequencies, 2)
+    val idf = new IDF(minDocFreq = 1)
+    val model = idf.fit(termFrequencies)
+    val expected = Vectors.dense(Array(0, 3, 1, 2).map { x =>
+      if (x > 0) {
+        math.log((m + 1.0) / (x + 1.0))
+      } else {
+        0
+      }
+    })
+    assert(model.idf ~== expected absTol 1e-12)
+    val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap()
+    assert(tfidf.size === 3)
+    val tfidf0 = tfidf(0L).asInstanceOf[SparseVector]
+    assert(tfidf0.indices === Array(1, 3))
+    assert(Vectors.dense(tfidf0.values) ~==
+      Vectors.dense(1.0 * expected(1), 2.0 * expected(3)) absTol 1e-12)
+    val tfidf1 = tfidf(1L).asInstanceOf[DenseVector]
+    assert(Vectors.dense(tfidf1.values) ~==
+      Vectors.dense(0.0, 1.0 * expected(1), 2.0 * expected(2), 3.0 * expected(3)) absTol 1e-12)
+    val tfidf2 = tfidf(2L).asInstanceOf[SparseVector]
+    assert(tfidf2.indices === Array(1))
+    assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12)
+  }
+
 }
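
As a worked example of the formulation above: with `m = 3` documents, a term that appears in all three gets `idf = log((3 + 1) / (3 + 1)) = log(1) = 0`, which is why both new tests can assert a TF-IDF of 0 for "this" regardless of `minDocFreq`, while a term appearing in a single document gets `idf = log(4 / 2) ≈ 0.693` unfiltered but 0 once `minDocFreq = 2`. The sketch below exercises the new constructor end to end; it is illustrative only (not code from this patch) and assumes a running `SparkContext` named `sc`.

```scala
import org.apache.spark.mllib.feature.{HashingTF, IDF}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Three toy documents; "another" and "still" each occur in only one of them.
val documents: RDD[Seq[String]] = sc.parallelize(Seq(
  "this is a sentence".split(" ").toSeq,
  "this is another sentence".split(" ").toSeq,
  "this is still a sentence".split(" ").toSeq))

val tf: RDD[Vector] = new HashingTF().transform(documents)
tf.cache()

// With minDocFreq = 2: "another" and "still" (document frequency 1) get an
// IDF of 0 and hence a TF-IDF of 0; "a" (document frequency 2) is kept; and
// "this", "is", "sentence" (document frequency 3) get log(4/4) = 0 anyway.
val tfidf: RDD[Vector] = new IDF(minDocFreq = 2).fit(tf).transform(tf)
```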