Commit b11887c0 authored by sethah, committed by Joseph K. Bradley

[SPARK-14264][PYSPARK][ML] Add feature importance for GBTs in pyspark

## What changes were proposed in this pull request?

Feature importances are exposed in the Python API for GBTs.

Other changes:
* Update the random forest feature importance documentation so it no longer repeats the decision tree docstring and instead references it.

## How was this patch tested?

Python doc tests were updated to validate GBT feature importance.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #12056 from sethah/Pyspark_GBT_feature_importance.
parent e7854028
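For orientation, a minimal sketch of how the new property is used from PySpark once this patch is applied, mirroring the updated GBTClassifier doctest below. The StringIndexer setup, the sqlContext handle, and the pyspark.mllib.linalg import path are assumptions based on the Spark-2.0.0-era doctests, not part of the diff itself.

    from pyspark.ml.classification import GBTClassifier
    from pyspark.ml.feature import StringIndexer
    from pyspark.mllib.linalg import Vectors  # pyspark.ml.linalg in later Spark releases

    # Toy data: one feature that perfectly separates the two labels.
    df = sqlContext.createDataFrame([
        (1.0, Vectors.dense(1.0)),
        (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    si_model = StringIndexer(inputCol="label", outputCol="indexed").fit(df)
    td = si_model.transform(df)

    gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42)
    model = gbt.fit(td)

    # New in this patch: GBTClassificationModel.featureImportances, analogous to the
    # existing decision tree and random forest properties.
    importances = model.featureImportances
    # On this toy data the doctest expects: SparseVector(1, {0: 1.0})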
@@ -396,7 +396,7 @@ class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLR
   - Normalize importances for tree to sum to 1.
 Note: Feature importance for single decision trees can have high variance due to
-      correlated predictor variables. Consider using a :class:`RandomForestClassifier`
+      correlated predictor variables. Consider using a :py:class:`RandomForestClassifier`
       to determine feature importance instead.
 """
 return self._call_java("featureImportances")
@@ -500,16 +500,12 @@ class RandomForestClassificationModel(TreeEnsembleModels):
 """
 Estimate of the importance of each feature.
-This generalizes the idea of "Gini" importance to other losses,
-following the explanation of Gini importance from "Random Forests" documentation
-by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+Each feature's importance is the average of its importance across all trees in the ensemble
+The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+(Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+and follows the implementation from scikit-learn.
-This feature importance is calculated as follows:
-  - Average over trees:
-     - importance(feature j) = sum (over nodes which split on feature j) of the gain,
-       where gain is scaled by the number of instances passing through node
-     - Normalize importances for tree to sum to 1.
-  - Normalize feature importance vector to sum to 1.
+.. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances`
 """
 return self._call_java("featureImportances")
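The shortened docstring states the computation tersely: each feature's importance is averaged across the trees in the ensemble, then the resulting vector is renormalized. A plain numpy sketch of that reading follows; it is an illustration only, not MLlib's actual Scala implementation, and ensemble_feature_importances is a hypothetical helper.

    import numpy as np

    def ensemble_feature_importances(per_tree_importances):
        # per_tree_importances: one 1-D array per tree, each already normalized
        # to sum to 1 (the per-tree step described for single decision trees).
        avg = np.mean(np.vstack(per_tree_importances), axis=0)  # average over trees
        return avg / avg.sum()                                  # normalize to sum to 1

    # Two trees over three features:
    ensemble_feature_importances([np.array([0.7, 0.2, 0.1]),
                                  np.array([0.5, 0.5, 0.0])])
    # -> array([0.6 , 0.35, 0.05])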
@@ -534,6 +530,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
 >>> td = si_model.transform(df)
 >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42)
 >>> model = gbt.fit(td)
+>>> model.featureImportances
+SparseVector(1, {0: 1.0})
 >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
 True
 >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -613,6 +611,21 @@ class GBTClassificationModel(TreeEnsembleModels):
 .. versionadded:: 1.4.0
 """
+@property
+@since("2.0.0")
+def featureImportances(self):
+    """
+    Estimate of the importance of each feature.
+    Each feature's importance is the average of its importance across all trees in the ensemble
+    The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+    (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+    and follows the implementation from scikit-learn.
+    .. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances`
+    """
+    return self._call_java("featureImportances")
 @inherit_doc
 class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
@@ -533,7 +533,7 @@ class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReada
   - Normalize importances for tree to sum to 1.
 Note: Feature importance for single decision trees can have high variance due to
-      correlated predictor variables. Consider using a :class:`RandomForestRegressor`
+      correlated predictor variables. Consider using a :py:class:`RandomForestRegressor`
       to determine feature importance instead.
 """
 return self._call_java("featureImportances")
@@ -626,16 +626,12 @@ class RandomForestRegressionModel(TreeEnsembleModels):
 """
 Estimate of the importance of each feature.
-This generalizes the idea of "Gini" importance to other losses,
-following the explanation of Gini importance from "Random Forests" documentation
-by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+Each feature's importance is the average of its importance across all trees in the ensemble
+The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+(Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+and follows the implementation from scikit-learn.
-This feature importance is calculated as follows:
-  - Average over trees:
-     - importance(feature j) = sum (over nodes which split on feature j) of the gain,
-       where gain is scaled by the number of instances passing through node
-     - Normalize importances for tree to sum to 1.
-  - Normalize feature importance vector to sum to 1.
+.. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances`
 """
 return self._call_java("featureImportances")
@@ -655,6 +651,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
 ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
 >>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42)
 >>> model = gbt.fit(df)
+>>> model.featureImportances
+SparseVector(1, {0: 1.0})
 >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
 True
 >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -734,6 +732,21 @@ class GBTRegressionModel(TreeEnsembleModels):
 .. versionadded:: 1.4.0
 """
+@property
+@since("2.0.0")
+def featureImportances(self):
+    """
+    Estimate of the importance of each feature.
+    Each feature's importance is the average of its importance across all trees in the ensemble
+    The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+    (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+    and follows the implementation from scikit-learn.
+    .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances`
+    """
+    return self._call_java("featureImportances")
 @inherit_doc
 class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
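The regression counterpart mirrors the updated GBTRegressor doctest. As above, the sqlContext handle and the pyspark.mllib.linalg import path are assumptions tied to the Spark version this patch targets; the snippet only illustrates reading the returned importance vector.

    from pyspark.ml.regression import GBTRegressor
    from pyspark.mllib.linalg import Vectors  # pyspark.ml.linalg in later Spark releases

    df = sqlContext.createDataFrame([
        (1.0, Vectors.dense(1.0)),
        (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    model = GBTRegressor(maxIter=5, maxDepth=2, seed=42).fit(df)

    imp = model.featureImportances  # a SparseVector with one entry per feature
    imp.toArray()                   # dense numpy array; entries sum to 1.0
    imp.toArray()[0]                # importance of feature 0; 1.0 on this toy data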