Commit cbeb006f authored by seddonm1, committed by Xiangrui Meng

[SPARK-13097][ML] Binarizer allowing Double AND Vector input types

This enhancement extends the existing Spark ML Binarizer [SPARK-5891] to accept Vector input columns in addition to the existing Double input column type.

One use case for this enhancement is binarizing many similar features at once using the same threshold value (for example, a binary threshold applied to many pixels in an image), as sketched below.
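For illustration, a minimal usage sketch of the new Vector path, in the same style as the tests added below (the column names and data are made up; it assumes an existing SQLContext named sqlContext, as in the Spark 1.6-era API this patch targets):

  import org.apache.spark.ml.feature.Binarizer
  import org.apache.spark.mllib.linalg.Vectors

  // Hypothetical data: each row holds a Vector of pixel intensities.
  val df = sqlContext.createDataFrame(Seq(
    (0, Vectors.dense(0.1, 0.8, 0.2)),
    (1, Vectors.dense(0.9, 0.3, 0.6))
  )).toDF("id", "pixels")

  val binarizer = new Binarizer()
    .setInputCol("pixels")             // a VectorUDT column now works, not only DoubleType
    .setOutputCol("binarized_pixels")
    .setThreshold(0.5)                 // entries > 0.5 become 1.0, the rest 0.0

  binarizer.transform(df).show()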

This contribution is my original work and I license the work to the project under the project's open source license.

cc: viirya, mengxr

Author: seddonm1 <seddonm1@gmail.com>

Closes #10976 from seddonm1/master.
parent adb54836
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -17,15 +17,18 @@
 package org.apache.spark.ml.feature
 
+import scala.collection.mutable.ArrayBuilder
+
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.BinaryAttribute
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
+import org.apache.spark.mllib.linalg._
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.sql.types._
 
 /**
  * :: Experimental ::
@@ -62,28 +65,53 @@ final class Binarizer(override val uid: String)
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   override def transform(dataset: DataFrame): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
+    val schema = dataset.schema
+    val inputType = schema($(inputCol)).dataType
     val td = $(threshold)
-    val binarizer = udf { in: Double => if (in > td) 1.0 else 0.0 }
-    val outputColName = $(outputCol)
-    val metadata = BinaryAttribute.defaultAttr.withName(outputColName).toMetadata()
-    dataset.select(col("*"),
-      binarizer(col($(inputCol))).as(outputColName, metadata))
+
+    val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 }
+    val binarizerVector = udf { (data: Vector) =>
+      val indices = ArrayBuilder.make[Int]
+      val values = ArrayBuilder.make[Double]
+
+      data.foreachActive { (index, value) =>
+        if (value > td) {
+          indices += index
+          values += 1.0
+        }
+      }
+
+      Vectors.sparse(data.size, indices.result(), values.result()).compressed
+    }
+
+    val metadata = outputSchema($(outputCol)).metadata
+
+    inputType match {
+      case DoubleType =>
+        dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata))
+      case _: VectorUDT =>
+        dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata))
+    }
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
-    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
-
-    val inputFields = schema.fields
+    val inputType = schema($(inputCol)).dataType
     val outputColName = $(outputCol)
-    require(inputFields.forall(_.name != outputColName),
-      s"Output column $outputColName already exists.")
-
-    val attr = BinaryAttribute.defaultAttr.withName(outputColName)
-    val outputFields = inputFields :+ attr.toStructField()
-    StructType(outputFields)
+
+    val outCol: StructField = inputType match {
+      case DoubleType =>
+        BinaryAttribute.defaultAttr.withName(outputColName).toStructField()
+      case _: VectorUDT =>
+        new StructField(outputColName, new VectorUDT, true)
+      case other =>
+        throw new IllegalArgumentException(s"Data type $other is not supported.")
+    }
+
+    if (schema.fieldNames.contains(outputColName)) {
+      throw new IllegalArgumentException(s"Output column $outputColName already exists.")
+    }
+
+    StructType(schema.fields :+ outCol)
   }
 
   override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
...
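A note on the vector path's design: binarizerVector visits only the active entries via foreachActive, records the indices that exceed the threshold, and builds the result as a sparse vector; .compressed then returns whichever of the dense or sparse representation uses less storage. The following standalone sketch extracts that logic into a plain function (binarize is a hypothetical helper name; mllib's linalg types are plain JVM objects, so no SparkContext is needed):

  import scala.collection.mutable.ArrayBuilder

  import org.apache.spark.mllib.linalg.{Vector, Vectors}

  // Same logic as the binarizerVector UDF above, as a plain function.
  def binarize(data: Vector, td: Double): Vector = {
    val indices = ArrayBuilder.make[Int]
    val values = ArrayBuilder.make[Double]
    data.foreachActive { (index, value) =>
      if (value > td) {
        indices += index
        values += 1.0
      }
    }
    // compressed picks dense or sparse, whichever uses less storage.
    Vectors.sparse(data.size, indices.result(), values.result()).compressed
  }

  println(binarize(Vectors.dense(0.1, 0.9, 0.0, 0.0, 0.0, 0.0, 0.7, 0.0), 0.5))
  // (8,[1,6],[1.0,1.0])  -- sparse wins here: only 2 of 8 entries exceed the threshold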
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 
@@ -68,6 +69,41 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
     }
   }
 
+  test("Binarize vector of continuous features with default parameter") {
+    val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
+    val dataFrame: DataFrame = sqlContext.createDataFrame(Seq(
+      (Vectors.dense(data), Vectors.dense(defaultBinarized))
+    )).toDF("feature", "expected")
+
+    val binarizer: Binarizer = new Binarizer()
+      .setInputCol("feature")
+      .setOutputCol("binarized_feature")
+
+    binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
+      case Row(x: Vector, y: Vector) =>
+        assert(x == y, "The feature value is not correct after binarization.")
+    }
+  }
+
+  test("Binarize vector of continuous features with setter") {
+    val threshold: Double = 0.2
+    val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
+    val dataFrame: DataFrame = sqlContext.createDataFrame(Seq(
+      (Vectors.dense(data), Vectors.dense(defaultBinarized))
+    )).toDF("feature", "expected")
+
+    val binarizer: Binarizer = new Binarizer()
+      .setInputCol("feature")
+      .setOutputCol("binarized_feature")
+      .setThreshold(threshold)
+
+    binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
+      case Row(x: Vector, y: Vector) =>
+        assert(x == y, "The feature value is not correct after binarization.")
+    }
+  }
+
   test("read/write") {
     val t = new Binarizer()
       .setInputCol("myInputCol")
...