Commit 234f781a authored by sethah, committed by Nick Pentreath

[SPARK-13787][ML][PYSPARK] Pyspark feature importances for decision tree and random forest

## What changes were proposed in this pull request?

This patch adds a `featureImportances` property to the Pyspark API for `DecisionTreeRegressionModel`, `DecisionTreeClassificationModel`, `RandomForestRegressionModel` and `RandomForestClassificationModel`.
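For illustration, here is a minimal end-to-end sketch of the new property. The local Spark setup and the toy data are hypothetical, and the import paths assume the PySpark API as of this patch:

```python
# Hypothetical end-to-end sketch; not part of the patch itself.
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.mllib.linalg import Vectors
from pyspark.ml.classification import DecisionTreeClassifier

sc = SparkContext("local[2]", "feature-importances-demo")
sqlContext = SQLContext(sc)

# Two-point toy dataset with a single feature.
df = sqlContext.createDataFrame(
    [(1.0, Vectors.dense(1.0)),
     (0.0, Vectors.sparse(1, [], []))],
    ["label", "features"])

model = DecisionTreeClassifier(maxDepth=2).fit(df)
print(model.featureImportances)  # e.g. SparseVector(1, {0: 1.0})

sc.stop()
```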

## How was this patch tested?

Python doc tests for the affected classes were updated to check feature importances.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #11622 from sethah/SPARK-13787.
Parent commit: 0b713e04
@@ -285,6 +285,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
3
>>> model.depth
1
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> result = model.transform(test0).head()
>>> result.prediction
@@ -352,6 +354,27 @@ class DecisionTreeClassificationModel(DecisionTreeModel):
.. versionadded:: 1.4.0
"""
    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Estimate of the importance of each feature.

        This generalizes the idea of "Gini" importance to other losses,
        following the explanation of Gini importance from the "Random Forests"
        documentation by Leo Breiman and Adele Cutler, and following the
        implementation from scikit-learn.

        This feature importance is calculated as follows:
         - importance(feature j) = sum (over nodes which split on feature j)
           of the gain, where the gain is scaled by the number of instances
           passing through the node
         - Normalize the importances for the tree to sum to 1.

        Note: Feature importance for single decision trees can have high
        variance due to correlated predictor variables. Consider using a
        :class:`RandomForestClassifier` to determine feature importance instead.
        """
        return self._call_java("featureImportances")
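To make the two steps above concrete, here is a small pure-Python sketch of the described rule. This is not the Spark implementation; the flat `(feature_index, gain, num_instances)` node representation is a hypothetical stand-in for the real tree structure:

```python
from collections import defaultdict

def tree_feature_importances(split_nodes):
    """Sketch of the rule above: each internal node credits its split
    feature with gain scaled by the number of instances reaching it;
    the per-feature totals are then normalized to sum to 1.

    `split_nodes`: iterable of (feature_index, gain, num_instances)
    tuples, one per internal node (a hypothetical representation).
    """
    importances = defaultdict(float)
    for feature, gain, num_instances in split_nodes:
        importances[feature] += gain * num_instances
    total = sum(importances.values())
    if total > 0:
        return {f: v / total for f, v in importances.items()}
    return dict(importances)

# Example: two nodes split on feature 0, one on feature 1.
print(tree_feature_importances([(0, 0.5, 100), (0, 0.2, 40), (1, 0.3, 60)]))
# {0: 0.7631..., 1: 0.2368...}
```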
@inherit_doc
class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
@@ -375,6 +398,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> td = si_model.transform(df)
>>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42)
>>> model = rf.fit(td)
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -443,6 +468,25 @@ class RandomForestClassificationModel(TreeEnsembleModels):
.. versionadded:: 1.4.0
"""
    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Estimate of the importance of each feature.

        This generalizes the idea of "Gini" importance to other losses,
        following the explanation of Gini importance from the "Random Forests"
        documentation by Leo Breiman and Adele Cutler, and following the
        implementation from scikit-learn.

        This feature importance is calculated as follows:
         - Average over trees:
            - importance(feature j) = sum (over nodes which split on feature j)
              of the gain, where the gain is scaled by the number of instances
              passing through the node
            - Normalize the importances for the tree to sum to 1.
         - Normalize the feature importance vector to sum to 1.
        """
        return self._call_java("featureImportances")
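A compact sketch of the ensemble variant described above, again hypothetical rather than the Spark code. Summing the per-tree importances and renormalizing is equivalent to averaging them and renormalizing when all trees carry equal weight:

```python
from collections import defaultdict

def forest_feature_importances(per_tree_importances):
    """Combine (already normalized) per-tree importance maps, then
    renormalize so the resulting vector sums to 1, per the steps above."""
    combined = defaultdict(float)
    for tree in per_tree_importances:  # each: dict of feature -> importance
        for feature, value in tree.items():
            combined[feature] += value
    total = sum(combined.values())
    return {f: v / total for f, v in combined.items()} if total > 0 else {}

# Example: two single-feature trees that agree importance belongs to feature 0.
print(forest_feature_importances([{0: 1.0}, {0: 1.0}]))  # {0: 1.0}
```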
@inherit_doc
class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
......
@@ -401,6 +401,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
1
>>> model.numNodes
3
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
@@ -499,6 +501,27 @@ class DecisionTreeRegressionModel(DecisionTreeModel):
.. versionadded:: 1.4.0
"""
    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Estimate of the importance of each feature.

        This generalizes the idea of "Gini" importance to other losses,
        following the explanation of Gini importance from the "Random Forests"
        documentation by Leo Breiman and Adele Cutler, and following the
        implementation from scikit-learn.

        This feature importance is calculated as follows:
         - importance(feature j) = sum (over nodes which split on feature j)
           of the gain, where the gain is scaled by the number of instances
           passing through the node
         - Normalize the importances for the tree to sum to 1.

        Note: Feature importance for single decision trees can have high
        variance due to correlated predictor variables. Consider using a
        :class:`RandomForestRegressor` to determine feature importance instead.
        """
        return self._call_java("featureImportances")
@inherit_doc
class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
@@ -515,6 +538,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42)
>>> model = rf.fit(df)
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 1.0])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -579,6 +604,25 @@ class RandomForestRegressionModel(TreeEnsembleModels):
.. versionadded:: 1.4.0
"""
    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Estimate of the importance of each feature.

        This generalizes the idea of "Gini" importance to other losses,
        following the explanation of Gini importance from the "Random Forests"
        documentation by Leo Breiman and Adele Cutler, and following the
        implementation from scikit-learn.

        This feature importance is calculated as follows:
         - Average over trees:
            - importance(feature j) = sum (over nodes which split on feature j)
              of the gain, where the gain is scaled by the number of instances
              passing through the node
            - Normalize the importances for the tree to sum to 1.
         - Normalize the feature importance vector to sum to 1.
        """
        return self._call_java("featureImportances")
@inherit_doc
class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
......