From 0d16faab90e4cd1f73c5b749dbda7bc2a400b26f Mon Sep 17 00:00:00 2001
From: Wayne Zhang <actuaryzhang@uber.com>
Date: Fri, 5 May 2017 10:23:58 +0800
Subject: [PATCH] [SPARK-20574][ML] Allow Bucketizer to handle non-Double
 numeric column

## What changes were proposed in this pull request?
Bucketizer currently requires the input column to be of type Double, but its bucketing logic works on any numeric type. Many practical datasets use integer or float columns, and manually casting them to Double before calling Bucketizer is tedious. This PR extends Bucketizer to accept all numeric input types by casting the input column to Double at transform time and relaxing the schema check to any numeric type.
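
For example, after this change the following works directly on a non-Double column (a minimal sketch; the data and split points are illustrative):

```scala
import org.apache.spark.ml.feature.Bucketizer

val df = spark.range(5).toDF("feature")  // "feature" is LongType, not DoubleType

val bucketizer = new Bucketizer()
  .setInputCol("feature")
  .setOutputCol("bucket")
  .setSplits(Array(Double.NegativeInfinity, 2.0, Double.PositiveInfinity))

// Before this patch, transformSchema rejected the LongType column with an
// IllegalArgumentException; now the input is cast to Double internally.
bucketizer.transform(df).show()
```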

## How was this patch tested?
New unit test in BucketizerSuite that exercises Bucketizer on Byte, Short, Integer, Long, Float, Double, and Decimal input columns.

Author: Wayne Zhang <actuaryzhang@uber.com>

Closes #17840 from actuaryzhang/bucketizer.
---
 .../apache/spark/ml/feature/Bucketizer.scala  |  4 +--
 .../spark/ml/feature/BucketizerSuite.scala    | 25 +++++++++++++++++++
 2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index d1f3b2af1e..bb8f2a3aa5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -116,7 +116,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
       Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid)
     }
 
-    val newCol = bucketizer(filteredDataset($(inputCol)))
+    val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType))
     val newField = prepOutputField(filteredDataset.schema)
     filteredDataset.withColumn($(outputCol), newCol, newField.metadata)
   }
@@ -130,7 +130,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
+    SchemaUtils.checkNumericType(schema, $(inputCol))
     SchemaUtils.appendColumn(schema, prepOutputField(schema))
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index aac29137d7..420fb17ddc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -26,6 +26,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
 
 class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
@@ -162,6 +164,29 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       .setSplits(Array(0.1, 0.8, 0.9))
     testDefaultReadWrite(t)
   }
+
+  test("Bucket numeric features") {
+    val splits = Array(-3.0, 0.0, 3.0)
+    val data = Array(-2.0, -1.0, 0.0, 1.0, 2.0)
+    val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0, 1.0)
+    val dataFrame: DataFrame = data.zip(expectedBuckets).toSeq.toDF("feature", "expected")
+
+    val bucketizer: Bucketizer = new Bucketizer()
+      .setInputCol("feature")
+      .setOutputCol("result")
+      .setSplits(splits)
+
+    val types = Seq(ShortType, IntegerType, LongType, FloatType, DoubleType,
+      ByteType, DecimalType(10, 0))
+    for (mType <- types) {
+      val df = dataFrame.withColumn("feature", col("feature").cast(mType))
+      bucketizer.transform(df).select("result", "expected").collect().foreach {
+        case Row(x: Double, y: Double) =>
+          assert(x === y, s"The result is not correct after bucketing in type $mType. " +
+            s"Expected $y but found $x.")
+      }
+    }
+  }
 }
 
 private object BucketizerSuite extends SparkFunSuite {
-- 
GitLab