Commit 425bcf6d authored by Bryan Cutler, committed by Nick Pentreath
[SPARK-13963][ML] Adding binary toggle param to HashingTF

## What changes were proposed in this pull request?
Adds a binary toggle parameter to ml.feature.HashingTF, as well as to mllib.feature.HashingTF, since the former wraps the latter's functionality. When this parameter is true, all non-zero term counts are set to 1, transforming term-count features into binary values that are well suited to discrete probabilistic models.
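
For illustration, a minimal sketch of the new toggle on the ML transformer, assuming a Spark shell of this era where `sqlContext` is in scope (the input data and column names mirror the unit test below and are otherwise arbitrary):

```scala
import org.apache.spark.ml.feature.HashingTF

// Hypothetical input: an id paired with a tokenized document.
val df = sqlContext.createDataFrame(Seq(
  (0, "a a b c c c".split(" ").toSeq)
)).toDF("id", "words")

val hashingTF = new HashingTF()
  .setInputCol("words")
  .setOutputCol("features")
  .setNumFeatures(100)
  .setBinary(true)  // the new toggle: non-zero term counts become 1.0

// With binary = true the feature vector for "a a b c c c" has all
// non-zero entries equal to 1.0 instead of the raw counts 2, 1, 3.
hashingTF.transform(df).select("features").show(false)
```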

## How was this patch tested?
Added unit tests for ML and MLlib

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #11832 from BryanCutler/binary-param-HashingTF-SPARK-13963.
parent 83775bc7
mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.{BooleanParam, IntParam, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
@@ -52,7 +52,18 @@ class HashingTF(override val uid: String)
   val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)",
     ParamValidators.gt(0))
-  setDefault(numFeatures -> (1 << 18))
+
+  /**
+   * Binary toggle to control term frequency counts.
+   * If true, all non-zero counts are set to 1. This is useful for discrete probabilistic
+   * models that model binary events rather than integer counts.
+   * (default = false)
+   * @group param
+   */
+  val binary = new BooleanParam(this, "binary", "If true, all non zero counts are set to 1. " +
+    "This is useful for discrete probabilistic models that model binary events rather " +
+    "than integer counts")
+
+  setDefault(numFeatures -> (1 << 18), binary -> false)
 
   /** @group getParam */
   def getNumFeatures: Int = $(numFeatures)
@@ -60,9 +71,15 @@ class HashingTF(override val uid: String)
   /** @group setParam */
   def setNumFeatures(value: Int): this.type = set(numFeatures, value)
 
+  /** @group getParam */
+  def getBinary: Boolean = $(binary)
+
+  /** @group setParam */
+  def setBinary(value: Boolean): this.type = set(binary, value)
+
   override def transform(dataset: DataFrame): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)
-    val hashingTF = new feature.HashingTF($(numFeatures))
+    val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
     val t = udf { terms: Seq[_] => hashingTF.transform(terms) }
     val metadata = outputSchema($(outputCol)).metadata
     dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
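
A quick sketch of the new ML-side accessors, assuming a Spark shell; the values follow from the default declared in the hunk above:

```scala
import org.apache.spark.ml.feature.HashingTF

val tf = new HashingTF()
tf.getBinary                   // false, the declared default
tf.setBinary(true).getBinary   // true after toggling
```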
mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
@@ -36,11 +36,23 @@ import org.apache.spark.util.Utils
 @Since("1.1.0")
 class HashingTF(val numFeatures: Int) extends Serializable {
 
+  private var binary = false
+
   /**
    */
   @Since("1.1.0")
   def this() = this(1 << 20)
 
+  /**
+   * If true, term frequency vector will be binary such that non-zero term counts will be set to 1
+   * (default: false)
+   */
+  @Since("2.0.0")
+  def setBinary(value: Boolean): this.type = {
+    binary = value
+    this
+  }
+
   /**
    * Returns the index of the input term.
    */
@@ -53,9 +65,10 @@ class HashingTF(val numFeatures: Int) extends Serializable {
   @Since("1.1.0")
   def transform(document: Iterable[_]): Vector = {
     val termFrequencies = mutable.HashMap.empty[Int, Double]
+    val setTF = if (binary) (i: Int) => 1.0 else (i: Int) => termFrequencies.getOrElse(i, 0.0) + 1.0
     document.foreach { term =>
       val i = indexOf(term)
-      termFrequencies.put(i, termFrequencies.getOrElse(i, 0.0) + 1.0)
+      termFrequencies.put(i, setTF(i))
     }
     Vectors.sparse(numFeatures, termFrequencies.toSeq)
   }
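
For context, a minimal sketch of the low-level mllib API that the ML transformer wraps; the concrete indices depend on the hash, but assuming no collisions among "a", "b", "c" the values are as commented:

```scala
import org.apache.spark.mllib.feature.HashingTF

val doc = "a a b c c c".split(" ").toSeq

// Default behaviour: raw term counts, i.e. values 2.0, 1.0, 3.0.
val tf = new HashingTF(100)
println(tf.transform(doc))

// New toggle: every non-zero count is clipped to 1.0.
val binaryTF = new HashingTF(100).setBinary(true)
println(binaryTF.transform(doc))
```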
mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -46,12 +46,30 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
     require(attrGroup.numAttributes === Some(n))
     val features = output.select("features").first().getAs[Vector](0)
     // Assume perfect hash on "a", "b", "c", and "d".
-    def idx(any: Any): Int = Utils.nonNegativeMod(any.##, n)
+    def idx: Any => Int = featureIdx(n)
     val expected = Vectors.sparse(n,
       Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
     assert(features ~== expected absTol 1e-14)
   }
 
+  test("applying binary term freqs") {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, "a a b c c c".split(" ").toSeq)
+    )).toDF("id", "words")
+    val n = 100
+    val hashingTF = new HashingTF()
+      .setInputCol("words")
+      .setOutputCol("features")
+      .setNumFeatures(n)
+      .setBinary(true)
+    val output = hashingTF.transform(df)
+    val features = output.select("features").first().getAs[Vector](0)
+    def idx: Any => Int = featureIdx(n)  // Assume perfect hash on input features
+    val expected = Vectors.sparse(n,
+      Seq((idx("a"), 1.0), (idx("b"), 1.0), (idx("c"), 1.0)))
+    assert(features ~== expected absTol 1e-14)
+  }
+
   test("read/write") {
     val t = new HashingTF()
       .setInputCol("myInputCol")
@@ -59,4 +77,8 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
       .setNumFeatures(10)
     testDefaultReadWrite(t)
   }
+
+  private def featureIdx(numFeatures: Int)(term: Any): Int = {
+    Utils.nonNegativeMod(term.##, numFeatures)
+  }
 }
mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.feature
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 
 class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -48,4 +49,15 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
     val docs = sc.parallelize(localDocs, 2)
     assert(hashingTF.transform(docs).collect().toSet === localDocs.map(hashingTF.transform).toSet)
   }
+
+  test("applying binary term freqs") {
+    val hashingTF = new HashingTF(100).setBinary(true)
+    val doc = "a a b c c c".split(" ")
+    val n = hashingTF.numFeatures
+    val expected = Vectors.sparse(n, Seq(
+      (hashingTF.indexOf("a"), 1.0),
+      (hashingTF.indexOf("b"), 1.0),
+      (hashingTF.indexOf("c"), 1.0)))
+    assert(hashingTF.transform(doc) ~== expected absTol 1e-14)
+  }
 }