From 5952bdb7df20d007d59f82261095faca3822c6f6 Mon Sep 17 00:00:00 2001
From: vectorijk <jiangkai@gmail.com>
Date: Tue, 6 Oct 2015 12:43:28 -0700
Subject: [PATCH] [SPARK-10688] [ML] [PYSPARK] Python API for AFTSurvivalRegression

Implement Python API for AFTSurvivalRegression

Author: vectorijk <jiangkai@gmail.com>

Closes #8926 from vectorijk/spark-10688.
---
 python/pyspark/ml/regression.py | 173 +++++++++++++++++++++++++++++++-
 1 file changed, 171 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 21d454f900..a0f7f54e65 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -22,8 +22,10 @@ from pyspark.ml.param.shared import *
 from pyspark.mllib.common import inherit_doc
 
 
-__all__ = ['DecisionTreeRegressor', 'DecisionTreeRegressionModel', 'GBTRegressor',
-           'GBTRegressionModel', 'LinearRegression', 'LinearRegressionModel',
+__all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
+           'DecisionTreeRegressor', 'DecisionTreeRegressionModel',
+           'GBTRegressor', 'GBTRegressionModel',
+           'LinearRegression', 'LinearRegressionModel',
            'RandomForestRegressor', 'RandomForestRegressionModel']
 
 
@@ -609,6 +611,173 @@ class GBTRegressionModel(TreeEnsembleModels):
     """
 
 
+@inherit_doc
+class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
+                            HasFitIntercept, HasMaxIter, HasTol):
+    """
+    Accelerated Failure Time (AFT) Model Survival Regression
+
+    Fit a parametric AFT survival regression model based on the Weibull distribution
+    of the survival time.
+
+    .. seealso:: `AFT Model <https://en.wikipedia.org/wiki/Accelerated_failure_time_model>`_
+
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> df = sqlContext.createDataFrame([
+    ...     (1.0, Vectors.dense(1.0), 1.0),
+    ...     (0.0, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
+    >>> aftsr = AFTSurvivalRegression()
+    >>> model = aftsr.fit(df)
+    >>> model.predict(Vectors.dense(6.3))
+    1.0
+    >>> model.predictQuantiles(Vectors.dense(6.3))
+    DenseVector([0.0101, 0.0513, 0.1054, 0.2877, 0.6931, 1.3863, 2.3026, 2.9957, 4.6052])
+    >>> model.transform(df).show()
+    +-----+---------+------+----------+
+    |label| features|censor|prediction|
+    +-----+---------+------+----------+
+    |  1.0|    [1.0]|   1.0|       1.0|
+    |  0.0|(1,[],[])|   0.0|       1.0|
+    +-----+---------+------+----------+
+    ...
+
+    .. versionadded:: 1.6.0
+    """
+
+    # a placeholder to make it appear in the generated doc
+    censorCol = Param(Params._dummy(), "censorCol",
+                      "censor column name. The value of this column could be 0 or 1. " +
+                      "If the value is 1, it means the event has occurred i.e. " +
+                      "uncensored; otherwise censored.")
+    quantileProbabilities = \
+        Param(Params._dummy(), "quantileProbabilities",
+              "quantile probabilities array. Values of the quantile probabilities array " +
+              "should be in the range (0, 1) and the array should be non-empty.")
+    quantilesCol = Param(Params._dummy(), "quantilesCol",
+                         "quantiles column name. This column will output quantiles of " +
+                         "corresponding quantileProbabilities if it is set.")
+
+    @keyword_only
+    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
+                 quantileProbabilities=None, quantilesCol=None):
+        """
+        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
+                 quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
+                 quantilesCol=None):
+        """
+        super(AFTSurvivalRegression, self).__init__()
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid)
+        #: Param for censor column name
+        self.censorCol = Param(self, "censorCol",
+                               "censor column name. The value of this column could be 0 or 1. " +
+                               "If the value is 1, it means the event has occurred i.e. " +
+                               "uncensored; otherwise censored.")
+        #: Param for quantile probabilities array
+        self.quantileProbabilities = \
+            Param(self, "quantileProbabilities",
+                  "quantile probabilities array. Values of the quantile probabilities array " +
+                  "should be in the range (0, 1) and the array should be non-empty.")
+        #: Param for quantiles column name
+        self.quantilesCol = Param(self, "quantilesCol",
+                                  "quantiles column name. This column will output quantiles of " +
+                                  "corresponding quantileProbabilities if it is set.")
+        self._setDefault(censorCol="censor",
+                         quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    @since("1.6.0")
+    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
+                  quantileProbabilities=None, quantilesCol=None):
+        """
+        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
+                  quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
+                  quantilesCol=None):
+        """
+        kwargs = self.setParams._input_kwargs
+        if quantileProbabilities is None:
+            # None means "not specified": keep any value set earlier (or the default
+            # registered in __init__) instead of resetting it on every setParams call.
+            kwargs = dict((k, v) for k, v in kwargs.items() if k != "quantileProbabilities")
+        return self._set(**kwargs)
+
+    def _create_model(self, java_model):
+        return AFTSurvivalRegressionModel(java_model)
+
+    @since("1.6.0")
+    def setCensorCol(self, value):
+        """
+        Sets the value of :py:attr:`censorCol`.
+        """
+        self._paramMap[self.censorCol] = value
+        return self
+
+    @since("1.6.0")
+    def getCensorCol(self):
+        """
+        Gets the value of censorCol or its default value.
+        """
+        return self.getOrDefault(self.censorCol)
+
+    @since("1.6.0")
+    def setQuantileProbabilities(self, value):
+        """
+        Sets the value of :py:attr:`quantileProbabilities`.
+        """
+        self._paramMap[self.quantileProbabilities] = value
+        return self
+
+    @since("1.6.0")
+    def getQuantileProbabilities(self):
+        """
+        Gets the value of quantileProbabilities or its default value.
+        """
+        return self.getOrDefault(self.quantileProbabilities)
+
+    @since("1.6.0")
+    def setQuantilesCol(self, value):
+        """
+        Sets the value of :py:attr:`quantilesCol`.
+        """
+        self._paramMap[self.quantilesCol] = value
+        return self
+
+    @since("1.6.0")
+    def getQuantilesCol(self):
+        """
+        Gets the value of quantilesCol or its default value.
+        """
+        return self.getOrDefault(self.quantilesCol)
+
+
+class AFTSurvivalRegressionModel(JavaModel):
+    """
+    Model fitted by AFTSurvivalRegression.
+
+    .. versionadded:: 1.6.0
+    """
+
+    @since("1.6.0")
+    def predictQuantiles(self, features):
+        """
+        Predict the quantiles for the given feature vector.
+        """
+        return self._call_java("predictQuantiles", features)
+
+    @since("1.6.0")
+    def predict(self, features):
+        """
+        Predict the value for the given feature vector.
+        """
+        return self._call_java("predict", features)
+
+
 if __name__ == "__main__":
     import doctest
     from pyspark.context import SparkContext
-- 
GitLab