Skip to content
Snippets Groups Projects
Commit f43a26ef authored by Yuhao Yang's avatar Yuhao Yang Committed by Joseph K. Bradley
Browse files

[SPARK-13629][ML] Add binary toggle Param to CountVectorizer

## What changes were proposed in this pull request?

This is a continued work for https://github.com/apache/spark/pull/11536#issuecomment-198511013,
containing some comment update and style adjustment.
jkbradley

## How was this patch tested?

unit tests.

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #11830 from hhbyyh/cvToggle.
parent 54794113
No related branches found
No related tags found
No related merge requests found
...@@ -207,13 +207,12 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin ...@@ -207,13 +207,12 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
def setMinTF(value: Double): this.type = set(minTF, value) def setMinTF(value: Double): this.type = set(minTF, value)
/** /**
* Binary toggle to control the output vector values. * Binary toggle to control the output vector values.
* If True, all non zero counts are set to 1. This is useful for discrete probabilistic * If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for
* models that model binary events rather than integer counts * discrete probabilistic models that model binary events rather than integer counts.
* * Default: false
* Default: false * @group param
* @group param */
*/
val binary: BooleanParam = val binary: BooleanParam =
new BooleanParam(this, "binary", "If True, all non zero counts are set to 1. " + new BooleanParam(this, "binary", "If True, all non zero counts are set to 1. " +
"This is useful for discrete probabilistic models that model binary events rather " + "This is useful for discrete probabilistic models that model binary events rather " +
...@@ -248,17 +247,13 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin ...@@ -248,17 +247,13 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
} }
tokenCount += 1 tokenCount += 1
} }
val effectiveMinTF = if (minTf >= 1.0) { val effectiveMinTF = if (minTf >= 1.0) minTf else tokenCount * minTf
minTf
} else {
tokenCount * minTf
}
val effectiveCounts = if ($(binary)) { val effectiveCounts = if ($(binary)) {
termCounts.filter(_._2 >= effectiveMinTF).map(p => (p._1, 1.0)).toSeq termCounts.filter(_._2 >= effectiveMinTF).map(p => (p._1, 1.0)).toSeq
} } else {
else {
termCounts.filter(_._2 >= effectiveMinTF).toSeq termCounts.filter(_._2 >= effectiveMinTF).toSeq
} }
Vectors.sparse(dictBr.value.size, effectiveCounts) Vectors.sparse(dictBr.value.size, effectiveCounts)
} }
dataset.withColumn($(outputCol), vectorizer(col($(inputCol)))) dataset.withColumn($(outputCol), vectorizer(col($(inputCol))))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment