Skip to content
Snippets Groups Projects
Commit e5d8d6e0 authored by Kai Jiang, committed by Joseph K. Bradley
Browse files

[SPARK-14373][PYSPARK] PySpark RandomForestClassifier, Regressor support export/import

## What changes were proposed in this pull request?
Support save/load for `RandomForest{Classifier, Regressor}` and their fitted models in the Python API.
[JIRA](https://issues.apache.org/jira/browse/SPARK-14373)
## How was this patch tested?
Doctests added to the `RandomForestClassifier` and `RandomForestRegressor` docstrings.

Author: Kai Jiang <jiangkai@gmail.com>

Closes #12238 from vectorijk/spark-14373.
parent a9b630f4
No related branches found
No related tags found
No related merge requests found
......@@ -621,7 +621,8 @@ class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLR
@inherit_doc
class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
HasRawPredictionCol, HasProbabilityCol,
RandomForestParams, TreeClassifierParams, HasCheckpointInterval):
RandomForestParams, TreeClassifierParams, HasCheckpointInterval,
JavaMLWritable, JavaMLReadable):
"""
`http://en.wikipedia.org/wiki/Random_forest Random Forest`
learning algorithm for classification.
......@@ -655,6 +656,16 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> rfc_path = temp_path + "/rfc"
>>> rf.save(rfc_path)
>>> rf2 = RandomForestClassifier.load(rfc_path)
>>> rf2.getNumTrees()
3
>>> model_path = temp_path + "/rfc_model"
>>> model.save(model_path)
>>> model2 = RandomForestClassificationModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
.. versionadded:: 1.4.0
"""
......@@ -703,7 +714,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
return RandomForestClassificationModel(java_model)
class RandomForestClassificationModel(TreeEnsembleModels):
class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
"""
Model fitted by RandomForestClassifier.
......
......@@ -782,7 +782,8 @@ class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReada
@inherit_doc
class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
RandomForestParams, TreeRegressorParams, HasCheckpointInterval):
RandomForestParams, TreeRegressorParams, HasCheckpointInterval,
JavaMLWritable, JavaMLReadable):
"""
`http://en.wikipedia.org/wiki/Random_forest Random Forest`
learning algorithm for regression.
......@@ -805,6 +806,16 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
0.5
>>> rfr_path = temp_path + "/rfr"
>>> rf.save(rfr_path)
>>> rf2 = RandomForestRegressor.load(rfr_path)
>>> rf2.getNumTrees()
2
>>> model_path = temp_path + "/rfr_model"
>>> model.save(model_path)
>>> model2 = RandomForestRegressionModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
.. versionadded:: 1.4.0
"""
......@@ -854,7 +865,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
return RandomForestRegressionModel(java_model)
class RandomForestRegressionModel(TreeEnsembleModels):
class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
"""
Model fitted by RandomForestRegressor.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment