diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 07cafa099374d8f0d7882a566d45e65fd269c334..f5335a3114b183fa3f66a190cd540c73e6241a2c 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -396,7 +396,7 @@ class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLR
           - Normalize importances for tree to sum to 1.
 
         Note: Feature importance for single decision trees can have high variance due to
-              correlated predictor variables. Consider using a :class:`RandomForestClassifier`
+              correlated predictor variables. Consider using a :py:class:`RandomForestClassifier`
               to determine feature importance instead.
         """
         return self._call_java("featureImportances")
@@ -500,16 +500,12 @@ class RandomForestClassificationModel(TreeEnsembleModels):
         """
         Estimate of the importance of each feature.
 
-        This generalizes the idea of "Gini" importance to other losses,
-        following the explanation of Gini importance from "Random Forests" documentation
-        by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+        Each feature's importance is the average of its importance across all trees in the ensemble.
+        The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+        (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2009.)
+        and follows the implementation from scikit-learn.
 
-        This feature importance is calculated as follows:
-         - Average over trees:
-            - importance(feature j) = sum (over nodes which split on feature j) of the gain,
-              where gain is scaled by the number of instances passing through node
-            - Normalize importances for tree to sum to 1.
-        - Normalize feature importance vector to sum to 1.
+        .. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances`
         """
         return self._call_java("featureImportances")
 
@@ -534,6 +530,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
     >>> td = si_model.transform(df)
     >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42)
     >>> model = gbt.fit(td)
+    >>> model.featureImportances
+    SparseVector(1, {0: 1.0})
     >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
     True
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -613,6 +611,21 @@ class GBTClassificationModel(TreeEnsembleModels):
     .. versionadded:: 1.4.0
     """
 
+    @property
+    @since("2.0.0")
+    def featureImportances(self):
+        """
+        Estimate of the importance of each feature.
+
+        Each feature's importance is the average of its importance across all trees in the ensemble.
+        The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+        (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2009.)
+        and follows the implementation from scikit-learn.
+
+        .. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances`
+        """
+        return self._call_java("featureImportances")
+
 
 @inherit_doc
 class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 37648549dee207d86e96e0aaa6f1f6ae9a4fb2b4..de8a5e4bed2e2a1f251ddf0896cd6ae6ded25507 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -533,7 +533,7 @@ class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReada
           - Normalize importances for tree to sum to 1.
 
         Note: Feature importance for single decision trees can have high variance due to
-              correlated predictor variables. Consider using a :class:`RandomForestRegressor`
+              correlated predictor variables. Consider using a :py:class:`RandomForestRegressor`
               to determine feature importance instead.
         """
         return self._call_java("featureImportances")
@@ -626,16 +626,12 @@ class RandomForestRegressionModel(TreeEnsembleModels):
         """
         Estimate of the importance of each feature.
 
-        This generalizes the idea of "Gini" importance to other losses,
-        following the explanation of Gini importance from "Random Forests" documentation
-        by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+        Each feature's importance is the average of its importance across all trees in the ensemble.
+        The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+        (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2009.)
+        and follows the implementation from scikit-learn.
 
-        This feature importance is calculated as follows:
-         - Average over trees:
-            - importance(feature j) = sum (over nodes which split on feature j) of the gain,
-              where gain is scaled by the number of instances passing through node
-            - Normalize importances for tree to sum to 1.
-        - Normalize feature importance vector to sum to 1.
+        .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances`
         """
         return self._call_java("featureImportances")
 
@@ -655,6 +651,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
     ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
     >>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42)
     >>> model = gbt.fit(df)
+    >>> model.featureImportances
+    SparseVector(1, {0: 1.0})
     >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
     True
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -734,6 +732,21 @@ class GBTRegressionModel(TreeEnsembleModels):
     .. versionadded:: 1.4.0
     """
 
+    @property
+    @since("2.0.0")
+    def featureImportances(self):
+        """
+        Estimate of the importance of each feature.
+
+        Each feature's importance is the average of its importance across all trees in the ensemble.
+        The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+        (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2009.)
+        and follows the implementation from scikit-learn.
+
+        .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances`
+        """
+        return self._call_java("featureImportances")
+
 
 @inherit_doc
 class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
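
As a quick sanity check of the API this patch adds, here is a minimal usage sketch (not part
of the patch itself) exercising the new GBT featureImportances property end to end. It assumes
pyspark >= 2.0 running with a local Spark session; the toy data mirrors the doctest above, so
all importance mass lands on the single feature.

    from pyspark.sql import SparkSession
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.regression import GBTRegressor

    # Local session for the sketch; any existing SparkSession works as well.
    spark = SparkSession.builder.master("local[2]").appName("gbt-fi").getOrCreate()

    # One-feature toy data, mirroring the doctest in the patch.
    df = spark.createDataFrame(
        [(1.0, Vectors.dense(1.0)),
         (0.0, Vectors.sparse(1, [], []))],
        ["label", "features"])

    model = GBTRegressor(maxIter=5, maxDepth=2, seed=42).fit(df)

    # Per-tree importances averaged over the ensemble and normalized to sum to 1;
    # with a single feature this prints SparseVector(1, {0: 1.0}).
    print(model.featureImportances)

    spark.stop()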