diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 922f8069fac49592f166c7738413ea8a0b962f04..6ef119a4265fdfc2afbaf7075f0f762f4fd8e10f 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -739,7 +739,8 @@ class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaML @inherit_doc class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - GBTParams, HasCheckpointInterval, HasStepSize, HasSeed): + GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable, + JavaMLReadable): """ `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` learning algorithm for classification. @@ -767,6 +768,18 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> gbtc_path = temp_path + "gbtc" + >>> gbt.save(gbtc_path) + >>> gbt2 = GBTClassifier.load(gbtc_path) + >>> gbt2.getMaxDepth() + 2 + >>> model_path = temp_path + "gbtc_model" + >>> model.save(model_path) + >>> model2 = GBTClassificationModel.load(model_path) + >>> model.featureImportances == model2.featureImportances + True + >>> model.treeWeights == model2.treeWeights + True .. versionadded:: 1.4.0 """ @@ -831,7 +844,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol return self.getOrDefault(self.lossType) -class GBTClassificationModel(TreeEnsembleModels): +class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable): """ Model fitted by GBTClassifier. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index c064fe500c3c0f67a6e4d76542b8369f6f49c65c..3c7852526a4812a5c9e833d67557aec3f99396ae 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -902,7 +902,8 @@ class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLRead @inherit_doc class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - GBTParams, HasCheckpointInterval, HasStepSize, HasSeed): + GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable, + JavaMLReadable): """ `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` learning algorithm for regression. @@ -925,6 +926,18 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> gbtr_path = temp_path + "gbtr" + >>> gbt.save(gbtr_path) + >>> gbt2 = GBTRegressor.load(gbtr_path) + >>> gbt2.getMaxDepth() + 2 + >>> model_path = temp_path + "gbtr_model" + >>> model.save(model_path) + >>> model2 = GBTRegressionModel.load(model_path) + >>> model.featureImportances == model2.featureImportances + True + >>> model.treeWeights == model2.treeWeights + True .. versionadded:: 1.4.0 """ @@ -989,7 +1002,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, return self.getOrDefault(self.lossType) -class GBTRegressionModel(TreeEnsembleModels): +class GBTRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable): """ Model fitted by GBTRegressor.