Skip to content
Snippets Groups Projects
Commit f3be369e authored by Tommy YU's avatar Tommy YU Committed by Xiangrui Meng
Browse files

[SPARK-13033] [ML] [PYSPARK] Add import/export for ml.regression

Add export/import for all estimators and transformers(which have Scala implementation) under pyspark/ml/regression.py.

yanboliang Please help to review.
For doctest, I though it's enough to add one since it's common usage. But I can add to all if we want it.

Author: Tommy YU <tummyyu@163.com>

Closes #11000 from Wenpei/spark-13033-ml.regression-exprot-import and squashes the following commits:

3646b36 [Tommy YU] address review comments
9cddc98 [Tommy YU] change base on review and pr 11197
cc61d9d [Tommy YU] remove default parameter set
19535d4 [Tommy YU] add export/import to regression
44a9dc2 [Tommy YU] add import/export for ml.regression
parent 90d07154
No related branches found
No related tags found
No related merge requests found
......@@ -154,7 +154,7 @@ class LinearRegressionModel(JavaModel, MLWritable, MLReadable):
@inherit_doc
class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
HasWeightCol):
HasWeightCol, MLWritable, MLReadable):
"""
.. note:: Experimental
......@@ -172,6 +172,18 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
0.0
>>> model.boundaries
DenseVector([0.0, 1.0])
>>> ir_path = temp_path + "/ir"
>>> ir.save(ir_path)
>>> ir2 = IsotonicRegression.load(ir_path)
>>> ir2.getIsotonic()
True
>>> model_path = temp_path + "/ir_model"
>>> model.save(model_path)
>>> model2 = IsotonicRegressionModel.load(model_path)
>>> model.boundaries == model2.boundaries
True
>>> model.predictions == model2.predictions
True
"""
isotonic = \
......@@ -237,7 +249,7 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
return self.getOrDefault(self.featureIndex)
class IsotonicRegressionModel(JavaModel):
class IsotonicRegressionModel(JavaModel, MLWritable, MLReadable):
"""
.. note:: Experimental
......@@ -663,7 +675,7 @@ class GBTRegressionModel(TreeEnsembleModels):
@inherit_doc
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
HasFitIntercept, HasMaxIter, HasTol):
HasFitIntercept, HasMaxIter, HasTol, MLWritable, MLReadable):
"""
Accelerated Failure Time (AFT) Model Survival Regression
......@@ -690,6 +702,20 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
| 0.0|(1,[],[])| 0.0| 1.0|
+-----+---------+------+----------+
...
>>> aftsr_path = temp_path + "/aftsr"
>>> aftsr.save(aftsr_path)
>>> aftsr2 = AFTSurvivalRegression.load(aftsr_path)
>>> aftsr2.getMaxIter()
100
>>> model_path = temp_path + "/aftsr_model"
>>> model.save(model_path)
>>> model2 = AFTSurvivalRegressionModel.load(model_path)
>>> model.coefficients == model2.coefficients
True
>>> model.intercept == model2.intercept
True
>>> model.scale == model2.scale
True
.. versionadded:: 1.6.0
"""
......@@ -787,7 +813,7 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
return self.getOrDefault(self.quantilesCol)
class AFTSurvivalRegressionModel(JavaModel):
class AFTSurvivalRegressionModel(JavaModel, MLWritable, MLReadable):
"""
Model fitted by AFTSurvivalRegression.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment