Skip to content
Snippets Groups Projects
Commit 0d16faab authored by Wayne Zhang's avatar Wayne Zhang Committed by Yanbo Liang
Browse files

[SPARK-20574][ML] Allow Bucketizer to handle non-Double numeric column

## What changes were proposed in this pull request?
Bucketizer currently requires the input column to be of type Double, but the logic should work on any numeric data type. Many practical problems involve integer or float data, and it can get very tedious to manually cast such columns to Double before calling Bucketizer. This PR extends Bucketizer to handle all numeric types.

## How was this patch tested?
New test.

Author: Wayne Zhang <actuaryzhang@uber.com>

Closes #17840 from actuaryzhang/bucketizer.
parent bfc8c79c
No related branches found
No related tags found
No related merge requests found
...@@ -116,7 +116,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String ...@@ -116,7 +116,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid)
} }
val newCol = bucketizer(filteredDataset($(inputCol))) val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType))
val newField = prepOutputField(filteredDataset.schema) val newField = prepOutputField(filteredDataset.schema)
filteredDataset.withColumn($(outputCol), newCol, newField.metadata) filteredDataset.withColumn($(outputCol), newCol, newField.metadata)
} }
...@@ -130,7 +130,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String ...@@ -130,7 +130,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
@Since("1.4.0") @Since("1.4.0")
override def transformSchema(schema: StructType): StructType = { override def transformSchema(schema: StructType): StructType = {
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) SchemaUtils.checkNumericType(schema, $(inputCol))
SchemaUtils.appendColumn(schema, prepOutputField(schema)) SchemaUtils.appendColumn(schema, prepOutputField(schema))
} }
......
...@@ -26,6 +26,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} ...@@ -26,6 +26,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
...@@ -162,6 +164,29 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa ...@@ -162,6 +164,29 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setSplits(Array(0.1, 0.8, 0.9)) .setSplits(Array(0.1, 0.8, 0.9))
testDefaultReadWrite(t) testDefaultReadWrite(t)
} }
test("Bucket numeric features") {
  // Each input value is paired with the bucket index it is expected to fall
  // into for the split points (-3.0, 0.0, 3.0) configured below.
  val valuesWithBuckets = Seq(-2.0 -> 0.0, -1.0 -> 0.0, 0.0 -> 1.0, 1.0 -> 1.0, 2.0 -> 1.0)
  val baseDf: DataFrame = valuesWithBuckets.toDF("feature", "expected")

  val bucketizer: Bucketizer = new Bucketizer()
    .setInputCol("feature")
    .setOutputCol("result")
    .setSplits(Array(-3.0, 0.0, 3.0))

  // Column types the "feature" column is cast to before bucketing.
  val numericTypes = Seq(ShortType, IntegerType, LongType, FloatType, DoubleType,
    ByteType, DecimalType(10, 0))

  numericTypes.foreach { numericType =>
    // Re-type the input column, then compare every produced bucket with its expectation.
    val castDf = baseDf.withColumn("feature", col("feature").cast(numericType))
    val rows = bucketizer.transform(castDf).select("result", "expected").collect()
    rows.foreach {
      case Row(actual: Double, expected: Double) =>
        assert(actual === expected, "The result is not correct after bucketing in type " +
          numericType.toString + ". " + s"Expected $expected but found $actual.")
    }
  }
}
} }
private object BucketizerSuite extends SparkFunSuite { private object BucketizerSuite extends SparkFunSuite {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment