Patch summary (free text before the first "diff --git" header):
  * AFTSurvivalRegression: the censor column is validated with checkNumericType
    instead of requiring DoubleType, and is cast to DoubleType when extracting AFTPoints.
  * IsotonicRegression: adds the missing space in the "isotonic" param description.
  * AFTSurvivalRegressionSuite: renames the label-type test and adds a censor-type test.
NOTE(review): line breaks and indentation were reconstructed from a whitespace-collapsed
copy of this patch; hunk line counts were checked, leading whitespace is inferred - confirm.

diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 2f78dd30b3af7525a1297289b1d52b17ed983a48..4b3608330c1bf9b2e47e3fd132b9ea2123babc75 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -106,7 +106,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
       fitting: Boolean): StructType = {
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     if (fitting) {
-      SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
+      SchemaUtils.checkNumericType(schema, $(censorCol))
       SchemaUtils.checkNumericType(schema, $(labelCol))
     }
     if (hasQuantilesCol) {
@@ -200,8 +200,8 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
    * and put it in an RDD with strong types.
    */
   protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = {
-    dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol)))
-      .rdd.map {
+    dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType),
+      col($(censorCol)).cast(DoubleType)).rdd.map {
         case Row(features: Vector, label: Double, censor: Double) =>
           AFTPoint(features, label, censor)
       }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index a6c29433d7303410f13639ceface85994d123ae0..529f66eadbcff578833da4ca1a22b62410a90ab4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -49,7 +49,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
    */
   final val isotonic: BooleanParam =
     new BooleanParam(this, "isotonic",
-      "whether the output sequence should be isotonic/increasing (true) or" +
+      "whether the output sequence should be isotonic/increasing (true) or " +
         "antitonic/decreasing (false)")
 
   /** @group getParam */
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 0fdfdf37cf38d90b70abe55d9c490cd555fedef6..3cd4b0ac308efae7952d09041f3bd463e6542b47 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -27,6 +27,8 @@ import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types._
 
 class AFTSurvivalRegressionSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -352,7 +354,7 @@ class AFTSurvivalRegressionSuite
     }
   }
 
-  test("should support all NumericType labels") {
+  test("should support all NumericType labels, and not support other types") {
     val aft = new AFTSurvivalRegression().setMaxIter(1)
     MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression](
       aft, spark, isClassification = false) { (expected, actual) =>
@@ -361,6 +363,36 @@ class AFTSurvivalRegressionSuite
     }
   }
 
+  test("should support all NumericType censors, and not support other types") {
+    val df = spark.createDataFrame(Seq(
+      (0, Vectors.dense(0)),
+      (1, Vectors.dense(1)),
+      (2, Vectors.dense(2)),
+      (3, Vectors.dense(3)),
+      (4, Vectors.dense(4))
+    )).toDF("label", "features")
+      .withColumn("censor", lit(0.0))
+    val aft = new AFTSurvivalRegression().setMaxIter(1)
+    val expected = aft.fit(df)
+
+    val types = Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DecimalType(10, 0))
+    types.foreach { t =>
+      val actual = aft.fit(df.select(col("label"), col("features"),
+        col("censor").cast(t)))
+      assert(expected.intercept === actual.intercept)
+      assert(expected.coefficients === actual.coefficients)
+    }
+
+    val dfWithStringCensors = spark.createDataFrame(Seq(
+      (0, Vectors.dense(0, 2, 3), "0")
+    )).toDF("label", "features", "censor")
+    val thrown = intercept[IllegalArgumentException] {
+      aft.fit(dfWithStringCensors)
+    }
+    assert(thrown.getMessage.contains(
+      "Column censor must be of type NumericType but was actually of type StringType"))
+  }
+
   test("numerical stability of standardization") {
     val trainer = new AFTSurvivalRegression()
     val model1 = trainer.fit(datasetUnivariate)