diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index d1f3b2af1e4824b8e5adf29da48034a16c3e9152..bb8f2a3aa5f71c5cbd43e051091eb6db9e3bdff9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -116,7 +116,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
       Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid)
     }
 
-    val newCol = bucketizer(filteredDataset($(inputCol)))
+    val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType))
     val newField = prepOutputField(filteredDataset.schema)
     filteredDataset.withColumn($(outputCol), newCol, newField.metadata)
   }
@@ -130,7 +130,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
+    SchemaUtils.checkNumericType(schema, $(inputCol))
     SchemaUtils.appendColumn(schema, prepOutputField(schema))
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index aac29137d79116c68aa14404ef90ee5c4b2d102d..420fb17ddce8cfc84d6a72f114d7a53b3ca1114f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -26,6 +26,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
 
 class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
@@ -162,6 +164,30 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       .setSplits(Array(0.1, 0.8, 0.9))
     testDefaultReadWrite(t)
   }
+
+  test("Bucket numeric features") {
+    val splits = Array(-3.0, 0.0, 3.0)
+    val data = Array(-2.0, -1.0, 0.0, 1.0, 2.0)
+    val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0, 1.0)
+    val dataFrame: DataFrame = data.zip(expectedBuckets).toSeq.toDF("feature", "expected")
+
+    val bucketizer: Bucketizer = new Bucketizer()
+      .setInputCol("feature")
+      .setOutputCol("result")
+      .setSplits(splits)
+
+    val types = Seq(ShortType, IntegerType, LongType, FloatType, DoubleType,
+      ByteType, DecimalType(10, 0))
+    for (mType <- types) {
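+      // Cast the input column to each numeric type and check the bucketized output.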
+      val df = dataFrame.withColumn("feature", col("feature").cast(mType))
+      bucketizer.transform(df).select("result", "expected").collect().foreach {
+        case Row(x: Double, y: Double) =>
+          assert(x === y, s"The result is not correct after bucketing in type $mType. " +
+            s"Expected $y but found $x.")
+      }
+    }
+  }
 }
 
 private object BucketizerSuite extends SparkFunSuite {
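
For reference, a minimal sketch of what this change enables: applying
Bucketizer directly to an IntegerType column, which previously failed the
DoubleType check in transformSchema. The SparkSession setup and column
names below are illustrative, not part of the patch.

    import org.apache.spark.ml.feature.Bucketizer
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("BucketizerNumericSketch")
      .getOrCreate()
    import spark.implicits._

    // An IntegerType column; before this patch, transform() threw an
    // IllegalArgumentException because the column was not DoubleType.
    val df = Seq(-2, -1, 0, 1, 2).toDF("feature")

    val bucketizer = new Bucketizer()
      .setInputCol("feature")
      .setOutputCol("bucket")
      .setSplits(Array(-3.0, 0.0, 3.0))

    // -2 and -1 fall into bucket 0 ([-3.0, 0.0)); 0, 1 and 2 into bucket 1.
    bucketizer.transform(df).show()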