diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index e29d7f48a1d6b8abe97cd7e1e20548c6836936de..aa92edde7acd10b320684d970a0aec733dfb9ec5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -58,7 +58,8 @@ private[ml] trait PredictorParams extends Params /** * :: DeveloperApi :: - * Abstraction for prediction problems (regression and classification). + * Abstraction for prediction problems (regression and classification). It accepts all NumericType + * labels and will automatically cast it to DoubleType in [[fit()]]. * * @tparam FeaturesType Type of features. * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. @@ -87,7 +88,12 @@ abstract class Predictor[ // This handles a few items such as schema validation. // Developers only need to implement train(). transformSchema(dataset.schema, logging = true) - copyValues(train(dataset).setParent(this)) + + // Cast LabelCol to DoubleType and keep the metadata. + val labelMeta = dataset.schema($(labelCol)).metadata + val casted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta) + + copyValues(train(casted).setParent(this)) } override def copy(extra: ParamMap): Learner @@ -121,7 +127,7 @@ abstract class Predictor[ * and put it in an RDD with strong types. */ protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = { - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index d1b21b16f234235fdb149782487576235a3c7a8a..a3da3067e1b5fcca470eaf83279db23de1c637ce 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -71,7 +71,7 @@ abstract class Classifier[ * and put it in an RDD with strong types. * * @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]]) - * and features ([[Vector]]). Labels are cast to [[DoubleType]]. + * and features ([[Vector]]). * @param numClasses Number of classes label can take. Labels must be integers in the range * [0, numClasses). * @throws SparkException if any label is not an integer >= 0 @@ -79,7 +79,7 @@ abstract class Classifier[ protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = { require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" + s" $numClasses, but requires numClasses > 0.") - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" + s" dataset with invalid label $label. Labels must be integers in range" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 8bffe0cda0327ac8f42b57888de08c0e1b4263a7..f8f164e8c14bde0913d4d782bddd9c69b66b8f62 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -128,7 +128,7 @@ class GBTClassifier @Since("1.4.0") ( // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports // 2 classes now. This lets us provide a more precise error message. val oldDataset: RDD[LabeledPoint] = - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label == 0 || label == 1, s"GBTClassifier was given" + s" dataset with invalid label $label. Labels must be in {0,1}; note that" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8fdaae04c42ec96a33f1dac9d7fb08ff111efcd4..c4651054fd7653231c1648424d760571f0bc6340 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -322,7 +322,7 @@ class LogisticRegression @Since("1.2.0") ( LogisticRegressionModel = { val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = - dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 994ed993c99dfac5525ed094757d1c7ae2cdcf0d..b03a07a6bc1e79b33dc5665cd292361096670ad1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -171,7 +171,7 @@ class NaiveBayes @Since("1.5.0") ( // Aggregates term frequencies per label. // TODO: Calling aggregateByKey and collect creates two stages, we can implement something // TODO: similar to reduceByKeyLocally to save one stage. - val aggregated = dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd + val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2))) }.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))( seqOp = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 33cb25c8c7f663b6655ebe1555c26850e219b3eb..8656ecf609ea40bd07fcf58c9d6c6e36730cd7d9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -255,7 +255,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = - dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 519f3bdec82dfdaf1d1b01c4a6f49228dffb9364..ae876b3839734da52ff707d5c687956487303656 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -190,7 +190,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = dataset.select( - col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..03e0c536a973e909c63d1d425c38c4688e5e4425 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext { + + import PredictorSuite._ + + test("should support all NumericType labels and not support other types") { + val df = spark.createDataFrame(Seq( + (0, Vectors.dense(0, 2, 3)), + (1, Vectors.dense(0, 3, 9)), + (0, Vectors.dense(0, 2, 6)) + )).toDF("label", "features") + + val types = + Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) + + val predictor = new MockPredictor() + + types.foreach { t => + predictor.fit(df.select(col("label").cast(t), col("features"))) + } + + intercept[IllegalArgumentException] { + predictor.fit(df.select(col("label").cast(StringType), col("features"))) + } + } +} + +object PredictorSuite { + + class MockPredictor(override val uid: String) + extends Predictor[Vector, MockPredictor, MockPredictionModel] { + + def this() = this(Identifiable.randomUID("mockpredictor")) + + override def train(dataset: Dataset[_]): MockPredictionModel = { + require(dataset.schema("label").dataType == DoubleType) + new MockPredictionModel(uid) + } + + override def copy(extra: ParamMap): MockPredictor = + throw new NotImplementedError() + } + + class MockPredictionModel(override val uid: String) + extends PredictionModel[Vector, MockPredictionModel] { + + def this() = this(Identifiable.randomUID("mockpredictormodel")) + + override def predict(features: Vector): Double = + throw new NotImplementedError() + + override def copy(extra: ParamMap): MockPredictionModel = + throw new NotImplementedError() + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index bc631dc6d31497349eb70906d65d63b1e8bd8a9c..8771fd2e9d2b25a9b08bcd5be4882f78e1eb0af8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -1807,7 +1807,6 @@ class LogisticRegressionSuite .objectiveHistory .sliding(2) .forall(x => x(0) >= x(1))) - } test("binary logistic regression with weighted data") {