diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index e29d7f48a1d6b8abe97cd7e1e20548c6836936de..aa92edde7acd10b320684d970a0aec733dfb9ec5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -58,7 +58,8 @@ private[ml] trait PredictorParams extends Params
 
 /**
  * :: DeveloperApi ::
- * Abstraction for prediction problems (regression and classification).
+ * Abstraction for prediction problems (regression and classification). It accepts all NumericType
+ * labels and will automatically cast them to DoubleType in [[fit()]].
  *
  * @tparam FeaturesType  Type of features.
  *                       E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
@@ -87,7 +88,12 @@ abstract class Predictor[
     // This handles a few items such as schema validation.
     // Developers only need to implement train().
     transformSchema(dataset.schema, logging = true)
-    copyValues(train(dataset).setParent(this))
+
+    // Cast the label column to DoubleType, preserving its metadata.
+    val labelMeta = dataset.schema($(labelCol)).metadata
+    val casted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
+
+    copyValues(train(casted).setParent(this))
   }
 
   override def copy(extra: ParamMap): Learner
@@ -121,7 +127,7 @@ abstract class Predictor[
    * and put it in an RDD with strong types.
    */
   protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
-    dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+    dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
       case Row(label: Double, features: Vector) => LabeledPoint(label, features)
     }
   }
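
Note: the three-argument withColumn overload used in fit() above carries the original column metadata through the cast, but it is private[spark] and not callable from user code. A roughly equivalent, metadata-preserving cast via the public API would look like the sketch below (column names illustrative, assuming a DataFrame df with "label" and "features" columns):

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.DoubleType

    // Cast the label to DoubleType while keeping its metadata, using only
    // public API: Column.as(alias, metadata) attaches explicit metadata.
    def castLabelToDouble(df: DataFrame): DataFrame = {
      val meta = df.schema("label").metadata
      df.select(col("label").cast(DoubleType).as("label", meta), col("features"))
    }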
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index d1b21b16f234235fdb149782487576235a3c7a8a..a3da3067e1b5fcca470eaf83279db23de1c637ce 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -71,7 +71,7 @@ abstract class Classifier[
    * and put it in an RDD with strong types.
    *
    * @param dataset  DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]])
-   *                 and features ([[Vector]]). Labels are cast to [[DoubleType]].
+   *                 and features ([[Vector]]).
    * @param numClasses  Number of classes label can take.  Labels must be integers in the range
    *                    [0, numClasses).
    * @throws SparkException  if any label is not an integer >= 0
@@ -79,7 +79,7 @@ abstract class Classifier[
   protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
     require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
       s" $numClasses, but requires numClasses > 0.")
-    dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+    dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
       case Row(label: Double, features: Vector) =>
         require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
           s" dataset with invalid label $label.  Labels must be integers in range" +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 8bffe0cda0327ac8f42b57888de08c0e1b4263a7..f8f164e8c14bde0913d4d782bddd9c69b66b8f62 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -128,7 +128,7 @@ class GBTClassifier @Since("1.4.0") (
     // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
     // 2 classes now.  This lets us provide a more precise error message.
     val oldDataset: RDD[LabeledPoint] =
-      dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+      dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
         case Row(label: Double, features: Vector) =>
           require(label == 0 || label == 1, s"GBTClassifier was given" +
             s" dataset with invalid label $label.  Labels must be in {0,1}; note that" +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 8fdaae04c42ec96a33f1dac9d7fb08ff111efcd4..c4651054fd7653231c1648424d760571f0bc6340 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -322,7 +322,7 @@ class LogisticRegression @Since("1.2.0") (
       LogisticRegressionModel = {
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
-      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
         case Row(label: Double, weight: Double, features: Vector) =>
           Instance(label, weight, features)
       }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 994ed993c99dfac5525ed094757d1c7ae2cdcf0d..b03a07a6bc1e79b33dc5665cd292361096670ad1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -171,7 +171,7 @@ class NaiveBayes @Since("1.5.0") (
     // Aggregates term frequencies per label.
     // TODO: Calling aggregateByKey and collect creates two stages, we can implement something
     // TODO: similar to reduceByKeyLocally to save one stage.
-    val aggregated = dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd
+    val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
       .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
       }.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))(
       seqOp = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 33cb25c8c7f663b6655ebe1555c26850e219b3eb..8656ecf609ea40bd07fcf58c9d6c6e36730cd7d9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -255,7 +255,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
 
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
-      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
         case Row(label: Double, weight: Double, features: Vector) =>
           Instance(label, weight, features)
       }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 519f3bdec82dfdaf1d1b01c4a6f49228dffb9364..ae876b3839734da52ff707d5c687956487303656 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -190,7 +190,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
 
     val instances: RDD[Instance] = dataset.select(
-      col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+      col($(labelCol)), w, col($(featuresCol))).rdd.map {
       case Row(label: Double, weight: Double, features: Vector) =>
         Instance(label, weight, features)
     }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..03e0c536a973e909c63d1d425c38c4688e5e4425
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  import PredictorSuite._
+
+  test("should support all NumericType labels and not support other types") {
+    val df = spark.createDataFrame(Seq(
+      (0, Vectors.dense(0, 2, 3)),
+      (1, Vectors.dense(0, 3, 9)),
+      (0, Vectors.dense(0, 2, 6))
+    )).toDF("label", "features")
+
+    val types =
+      Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
+
+    val predictor = new MockPredictor()
+
+    types.foreach { t =>
+      predictor.fit(df.select(col("label").cast(t), col("features")))
+    }
+
+    intercept[IllegalArgumentException] {
+      predictor.fit(df.select(col("label").cast(StringType), col("features")))
+    }
+  }
+}
+
+object PredictorSuite {
+
+  class MockPredictor(override val uid: String)
+    extends Predictor[Vector, MockPredictor, MockPredictionModel] {
+
+    def this() = this(Identifiable.randomUID("mockpredictor"))
+
+    override def train(dataset: Dataset[_]): MockPredictionModel = {
+      require(dataset.schema("label").dataType == DoubleType)
+      new MockPredictionModel(uid)
+    }
+
+    override def copy(extra: ParamMap): MockPredictor =
+      throw new NotImplementedError()
+  }
+
+  class MockPredictionModel(override val uid: String)
+    extends PredictionModel[Vector, MockPredictionModel] {
+
+    def this() = this(Identifiable.randomUID("mockpredictormodel"))
+
+    override def predict(features: Vector): Double =
+      throw new NotImplementedError()
+
+    override def copy(extra: ParamMap): MockPredictionModel =
+      throw new NotImplementedError()
+  }
+}
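
For context on the intercept[IllegalArgumentException] in the test above: transformSchema validates the label column before train() runs, so a StringType label fails fast. A sketch of that numeric-type check (illustrative, not the exact Spark implementation; the require() is what raises IllegalArgumentException):

    import org.apache.spark.sql.types._

    // Reject any label column whose type is not a NumericType.
    def checkNumericLabel(schema: StructType, colName: String): Unit = {
      val dt = schema(colName).dataType
      require(dt.isInstanceOf[NumericType],
        s"Column $colName must be of NumericType but was actually $dt.")
    }

    checkNumericLabel(StructType(Seq(StructField("label", IntegerType))), "label")  // passes
    // A StringType label would throw IllegalArgumentException here.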
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index bc631dc6d31497349eb70906d65d63b1e8bd8a9c..8771fd2e9d2b25a9b08bcd5be4882f78e1eb0af8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -1807,7 +1807,6 @@ class LogisticRegressionSuite
         .objectiveHistory
         .sliding(2)
         .forall(x => x(0) >= x(1)))
-
   }
 
   test("binary logistic regression with weighted data") {