Skip to content
Snippets Groups Projects
Commit 50c08e82 authored by Zheng RuiFeng's avatar Zheng RuiFeng Committed by Nick Pentreath
Browse files

[SPARK-19704][ML] AFTSurvivalRegression should support numeric censorCol

## What changes were proposed in this pull request?
make `AFTSurvivalRegression` support numeric censorCol
## How was this patch tested?
existing tests and added tests

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #17034 from zhengruifeng/aft_numeric_censor.
parent 625cfe09
No related branches found
No related tags found
No related merge requests found
......@@ -106,7 +106,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
fitting: Boolean): StructType = {
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
if (fitting) {
SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
SchemaUtils.checkNumericType(schema, $(censorCol))
SchemaUtils.checkNumericType(schema, $(labelCol))
}
if (hasQuantilesCol) {
......@@ -200,8 +200,8 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
* and put it in an RDD with strong types.
*/
protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = {
dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol)))
.rdd.map {
dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType),
col($(censorCol)).cast(DoubleType)).rdd.map {
case Row(features: Vector, label: Double, censor: Double) =>
AFTPoint(features, label, censor)
}
......
......@@ -49,7 +49,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
*/
final val isotonic: BooleanParam =
new BooleanParam(this, "isotonic",
"whether the output sequence should be isotonic/increasing (true) or" +
"whether the output sequence should be isotonic/increasing (true) or " +
"antitonic/decreasing (false)")
/** @group getParam */
......
......@@ -27,6 +27,8 @@ import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types._
class AFTSurvivalRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
......@@ -352,7 +354,7 @@ class AFTSurvivalRegressionSuite
}
}
test("should support all NumericType labels") {
test("should support all NumericType labels, and not support other types") {
val aft = new AFTSurvivalRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression](
aft, spark, isClassification = false) { (expected, actual) =>
......@@ -361,6 +363,36 @@ class AFTSurvivalRegressionSuite
}
}
test("should support all NumericType censors, and not support other types") {
val df = spark.createDataFrame(Seq(
(0, Vectors.dense(0)),
(1, Vectors.dense(1)),
(2, Vectors.dense(2)),
(3, Vectors.dense(3)),
(4, Vectors.dense(4))
)).toDF("label", "features")
.withColumn("censor", lit(0.0))
val aft = new AFTSurvivalRegression().setMaxIter(1)
val expected = aft.fit(df)
val types = Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DecimalType(10, 0))
types.foreach { t =>
val actual = aft.fit(df.select(col("label"), col("features"),
col("censor").cast(t)))
assert(expected.intercept === actual.intercept)
assert(expected.coefficients === actual.coefficients)
}
val dfWithStringCensors = spark.createDataFrame(Seq(
(0, Vectors.dense(0, 2, 3), "0")
)).toDF("label", "features", "censor")
val thrown = intercept[IllegalArgumentException] {
aft.fit(dfWithStringCensors)
}
assert(thrown.getMessage.contains(
"Column censor must be of type NumericType but was actually of type StringType"))
}
test("numerical stability of standardization") {
val trainer = new AFTSurvivalRegression()
val model1 = trainer.fit(datasetUnivariate)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment