Skip to content
Snippets Groups Projects
Commit 5952bdb7 authored by vectorijk's avatar vectorijk Committed by Xiangrui Meng
Browse files

[SPARK-10688] [ML] [PYSPARK] Python API for AFTSurvivalRegression

Implement Python API for AFTSurvivalRegression

Author: vectorijk <jiangkai@gmail.com>

Closes #8926 from vectorijk/spark-10688.
parent e9783601
No related branches found
No related tags found
No related merge requests found
......@@ -22,8 +22,10 @@ from pyspark.ml.param.shared import *
from pyspark.mllib.common import inherit_doc
__all__ = ['DecisionTreeRegressor', 'DecisionTreeRegressionModel', 'GBTRegressor',
'GBTRegressionModel', 'LinearRegression', 'LinearRegressionModel',
__all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
'DecisionTreeRegressor', 'DecisionTreeRegressionModel',
'GBTRegressor', 'GBTRegressionModel',
'LinearRegression', 'LinearRegressionModel',
'RandomForestRegressor', 'RandomForestRegressionModel']
......@@ -609,6 +611,171 @@ class GBTRegressionModel(TreeEnsembleModels):
"""
@inherit_doc
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
HasFitIntercept, HasMaxIter, HasTol):
"""
Accelerated Failure Time (AFT) Model Survival Regression
Fit a parametric AFT survival regression model based on the Weibull distribution
of the survival time.
.. seealso:: `AFT Model <https://en.wikipedia.org/wiki/Accelerated_failure_time_model>`_
>>> from pyspark.mllib.linalg import Vectors
>>> df = sqlContext.createDataFrame([
... (1.0, Vectors.dense(1.0), 1.0),
... (0.0, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
>>> aftsr = AFTSurvivalRegression()
>>> model = aftsr.fit(df)
>>> model.predict(Vectors.dense(6.3))
1.0
>>> model.predictQuantiles(Vectors.dense(6.3))
DenseVector([0.0101, 0.0513, 0.1054, 0.2877, 0.6931, 1.3863, 2.3026, 2.9957, 4.6052])
>>> model.transform(df).show()
+-----+---------+------+----------+
|label| features|censor|prediction|
+-----+---------+------+----------+
| 1.0| [1.0]| 1.0| 1.0|
| 0.0|(1,[],[])| 0.0| 1.0|
+-----+---------+------+----------+
...
.. versionadded:: 1.6.0
"""
# a placeholder to make it appear in the generated doc
censorCol = Param(Params._dummy(), "censorCol",
"censor column name. The value of this column could be 0 or 1. " +
"If the value is 1, it means the event has occurred i.e. " +
"uncensored; otherwise censored.")
quantileProbabilities = \
Param(Params._dummy(), "quantileProbabilities",
"quantile probabilities array. Values of the quantile probabilities array " +
"should be in the range (0, 1) and the array should be non-empty.")
quantilesCol = Param(Params._dummy(), "quantilesCol",
"quantiles column name. This column will output quantiles of " +
"corresponding quantileProbabilities if it is set.")
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
quantileProbabilities=None, quantilesCol=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
quantilesCol=None):
"""
super(AFTSurvivalRegression, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid)
#: Param for censor column name
self.censorCol = Param(self, "censorCol",
"censor column name. The value of this column could be 0 or 1. " +
"If the value is 1, it means the event has occurred i.e. " +
"uncensored; otherwise censored.")
#: Param for quantile probabilities array
self.quantileProbabilities = \
Param(self, "quantileProbabilities",
"quantile probabilities array. Values of the quantile probabilities array " +
"should be in the range (0, 1) and the array should be non-empty.")
#: Param for quantiles column name
self.quantilesCol = Param(self, "quantilesCol",
"quantiles column name. This column will output quantiles of " +
"corresponding quantileProbabilities if it is set.")
self._setDefault(censorCol="censor",
quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.6.0")
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
quantileProbabilities=None, quantilesCol=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
quantilesCol=None):
"""
kwargs = self.setParams._input_kwargs
if quantileProbabilities is None:
return self._set(**kwargs).setQuantileProbabilities([0.01, 0.05, 0.1, 0.25, 0.5,
0.75, 0.9, 0.95, 0.99])
else:
return self._set(**kwargs)
def _create_model(self, java_model):
return AFTSurvivalRegressionModel(java_model)
@since("1.6.0")
def setCensorCol(self, value):
"""
Sets the value of :py:attr:`censorCol`.
"""
self._paramMap[self.censorCol] = value
return self
@since("1.6.0")
def getCensorCol(self):
"""
Gets the value of censorCol or its default value.
"""
return self.getOrDefault(self.censorCol)
@since("1.6.0")
def setQuantileProbabilities(self, value):
"""
Sets the value of :py:attr:`quantileProbabilities`.
"""
self._paramMap[self.quantileProbabilities] = value
return self
@since("1.6.0")
def getQuantileProbabilities(self):
"""
Gets the value of quantileProbabilities or its default value.
"""
return self.getOrDefault(self.quantileProbabilities)
@since("1.6.0")
def setQuantilesCol(self, value):
"""
Sets the value of :py:attr:`quantilesCol`.
"""
self._paramMap[self.quantilesCol] = value
return self
@since("1.6.0")
def getQuantilesCol(self):
"""
Gets the value of quantilesCol or its default value.
"""
return self.getOrDefault(self.quantilesCol)
class AFTSurvivalRegressionModel(JavaModel):
"""
Model fitted by AFTSurvivalRegression.
.. versionadded:: 1.6.0
"""
def predictQuantiles(self, features):
"""
Predicted Quantiles
"""
return self._call_java("predictQuantiles", features)
def predict(self, features):
"""
Predicted value
"""
return self._call_java("predict", features)
if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment