diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index cfecae7e0b152c7d535158f5ec2d09cbbd101f72..29f55a7f715ca5b29c933abf53a2b6f0e245f00b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -80,7 +80,7 @@ final class ChiSqSelector(override val uid: String) @Since("2.0.0") override def fit(dataset: Dataset[_]): ChiSqSelectorModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(labelCol), $(featuresCol)).rdd.map { + val input = dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } @@ -90,7 +90,7 @@ final class ChiSqSelector(override val uid: String) override def transformSchema(schema: StructType): StructType = { SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(labelCol)) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 5219680be2dc87149000e7f7f42ac1c6a04c7a99..a2f3d44132d1d4e982167b8dd6c76ba2f0e739b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -256,8 +256,8 @@ class RFormulaModel private[feature]( val columnNames = schema.map(_.name) require(!columnNames.contains($(featuresCol)), "Features column already exists.") require( - !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType, - "Label column already exists and is not of type DoubleType.") + !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType], + "Label column already exists and is not of type NumericType.") } @Since("2.0.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index f94d336df544d429af5d4b40f45c56c08af29d35..91a947f44bc31639bc7d8e93d838c1e041e8d8d9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -337,7 +337,7 @@ class DecisionTreeClassifierSuite test("should support all NumericType labels and not support other types") { val dt = new DecisionTreeClassifier().setMaxDepth(1) MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier]( - dt, isClassification = true, spark) { (expected, actual) => + dt, spark) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index c9453aaec255df99a292b433ac1d231f551ea709..5a5e5c15fc59c3bf00f10890307881d784ae6e2e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -106,7 +106,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext test("should support all NumericType labels and not support other types") { val gbt = new GBTClassifier().setMaxDepth(1) MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier]( - gbt, isClassification = true, spark) { (expected, actual) => + gbt, spark) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index cb4d087ce5bc2051183ee3cf20be71126c83f75a..f127aa217c94d60917fd8384e7d7107190db36c9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -938,7 +938,7 @@ class LogisticRegressionSuite test("should support all NumericType labels and not support other types") { val lr = new LogisticRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression]( - lr, isClassification = true, spark) { (expected, actual) => + lr, spark) { (expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients.toArray === actual.coefficients.toArray) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 876e047db54cfd9ca030bc310f82793d856b7b71..d5282e07d65c228e9e52974fc65b4c9e78d26dd5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -169,7 +169,7 @@ class MultilayerPerceptronClassifierSuite val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1) MLTestingUtils.checkNumericTypes[ MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier]( - mpc, isClassification = true, spark) { (expected, actual) => + mpc, spark) { (expected, actual) => assert(expected.layers === actual.layers) assert(expected.weights === actual.weights) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 15d0358c3fc0370fe2f20c1e415b3e3c734c28f6..2a05c446e5169a010f7831b4f3800c382531c296 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -188,7 +188,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa test("should support all NumericType labels and not support other types") { val nb = new NaiveBayes() MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes]( - nb, isClassification = true, spark) { (expected, actual) => + nb, spark) { (expected, actual) => assert(expected.pi === actual.pi) assert(expected.theta === actual.theta) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 005d609307fbfdbe8542caacca07b479cf5f3443..5044d40998d66558b86725f380b3de903b3091f7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -228,7 +228,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("should support all NumericType labels and not support other types") { val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1)) MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest]( - ovr, isClassification = true, spark) { (expected, actual) => + ovr, spark) { (expected, actual) => val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel]) val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel]) assert(expectedModels.length === actualModels.length) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 97f3feacca0772b38c4a295247881bfbf8c3a525..8002a2f4f29e108baacd0668ab6a9edfeb207afb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -189,7 +189,7 @@ class RandomForestClassifierSuite test("should support all NumericType labels and not support other types") { val rf = new RandomForestClassifier().setMaxDepth(1) MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier]( - rf, isClassification = true, spark) { (expected, actual) => + rf, spark) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 4c6d9c5e26097cac17640e2d83c856f420a52924..4fcc9745b738c60fa6806e29ebca0ee883541923 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint @@ -81,4 +81,12 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext val newInstance = testDefaultReadWrite(instance) assert(newInstance.selectedFeatures === instance.selectedFeatures) } + + test("should support all NumericType labels and not support other types") { + val css = new ChiSqSelector() + MLTestingUtils.checkNumericTypes[ChiSqSelectorModel, ChiSqSelector]( + css, spark) { (expected, actual) => + assert(expected.selectedFeatures === actual.selectedFeatures) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 46e7495297a6b9de0e56b170a801eff1b015f299..c623a6210bdae7df3cb898754cb947f4e6c21894 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.types.DoubleType class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { @@ -68,9 +68,9 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(resultSchema.toString == model.transform(original).schema.toString) } - test("label column already exists but is not double type") { + test("label column already exists but is not numeric type") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") - val original = spark.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y") + val original = spark.createDataFrame(Seq((0, true), (2, false))).toDF("x", "y") val model = formula.fit(original) intercept[IllegalArgumentException] { model.transformSchema(original.schema) @@ -134,7 +134,6 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul ).toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = spark.createDataFrame( Seq( ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), @@ -188,7 +187,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul "vec2", Array[Attribute]( NumericAttribute.defaultAttr, - NumericAttribute.defaultAttr)).toMetadata + NumericAttribute.defaultAttr)).toMetadata() val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata)) val model = formula.fit(original) val result = model.transform(original) @@ -309,4 +308,23 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val newModel = testDefaultReadWrite(model) checkModelData(model, newModel) } + + test("should support all NumericType labels") { + val formula = new RFormula().setFormula("label ~ features") + .setLabelCol("x") + .setFeaturesCol("y") + val dfs = MLTestingUtils.genRegressionDFWithNumericLabelCol(spark) + val expected = formula.fit(dfs(DoubleType)) + val actuals = dfs.keys.filter(_ != DoubleType).map(t => formula.fit(dfs(t))) + actuals.foreach { actual => + assert(expected.pipelineModel.stages.length === actual.pipelineModel.stages.length) + expected.pipelineModel.stages.zip(actual.pipelineModel.stages).foreach { + case (exTransformer, acTransformer) => + assert(exTransformer.params === acTransformer.params) + } + assert(expected.resolvedFormula.label === actual.resolvedFormula.label) + assert(expected.resolvedFormula.terms === actual.resolvedFormula.terms) + assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index f8fc775676c0caf0ddc5af41561048399a5977f4..e4772df622d18d5084526e69eb752c72176c6c7d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -356,7 +356,7 @@ class AFTSurvivalRegressionSuite test("should support all NumericType labels") { val aft = new AFTSurvivalRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression]( - aft, isClassification = false, spark) { (expected, actual) => + aft, spark, isClassification = false) { (expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients === actual.coefficients) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index d9f26ad8dc93c4424f8073706927e17568031652..2d30cbf36766b3a63433f5494838ad61c89f7491 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -120,7 +120,7 @@ class DecisionTreeRegressorSuite test("should support all NumericType labels and not support other types") { val dt = new DecisionTreeRegressor().setMaxDepth(1) MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor]( - dt, isClassification = false, spark) { (expected, actual) => + dt, spark, isClassification = false) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index f6ea5bb741d41cf4ca55154d1759f9d93ff0583b..ac833b833d7d83f686790e355e9b8461863d774d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -115,7 +115,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext test("should support all NumericType labels and not support other types") { val gbt = new GBTRegressor().setMaxDepth(1) MLTestingUtils.checkNumericTypes[GBTRegressionModel, GBTRegressor]( - gbt, isClassification = false, spark) { (expected, actual) => + gbt, spark, isClassification = false) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 161f8c80f8df5bdcc55498c05e793a0be4f1281d..3d9aeb8c0a2d0e155b7cd520ad99097c43e4334c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -1021,7 +1021,7 @@ class GeneralizedLinearRegressionSuite val glr = new GeneralizedLinearRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[ GeneralizedLinearRegressionModel, GeneralizedLinearRegression]( - glr, isClassification = false, spark) { (expected, actual) => + glr, spark, isClassification = false) { (expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients === actual.coefficients) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index 9bf7542b1259b3ca194fbe406fea4d985ba084dd..bed4978b25b37c5ff3cc56cf98342d49818e1493 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -184,7 +184,7 @@ class IsotonicRegressionSuite test("should support all NumericType labels and not support other types") { val ir = new IsotonicRegression() MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression]( - ir, isClassification = false, spark) { (expected, actual) => + ir, spark, isClassification = false) { (expected, actual) => assert(expected.boundaries === actual.boundaries) assert(expected.predictions === actual.predictions) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 10f547b67375c4a86873f37ffdbf6b1a1ff95931..a98227d2c14fb9c0f65fc6d6a05904a9bc8c1ee4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -1010,7 +1010,7 @@ class LinearRegressionSuite test("should support all NumericType labels and not support other types") { val lr = new LinearRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression]( - lr, isClassification = false, spark) { (expected, actual) => + lr, spark, isClassification = false) { (expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients === actual.coefficients) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 72f3c65eb8c7d6c780144691a6eb04e9d18e2677..7a3a3698f950d1d74dbac9b05a7f9301f2252002 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -98,7 +98,7 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex test("should support all NumericType labels and not support other types") { val rf = new RandomForestRegressor().setMaxDepth(1) MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, RandomForestRegressor]( - rf, isClassification = false, spark) { (expected, actual) => + rf, spark, isClassification = false) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index 4fe473bbacd426028fc6df145b2b28cff24882db..ad7d2c9b8d40c7a36230af73eebb4554681ef9c6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -37,8 +37,8 @@ object MLTestingUtils extends SparkFunSuite { def checkNumericTypes[M <: Model[M], T <: Estimator[M]]( estimator: T, - isClassification: Boolean, - spark: SparkSession)(check: (M, M) => Unit): Unit = { + spark: SparkSession, + isClassification: Boolean = true)(check: (M, M) => Unit): Unit = { val dfs = if (isClassification) { genClassifDFWithNumericLabelCol(spark) } else {