Skip to content
Snippets Groups Projects
Commit 1dbc4a15 authored by MechCoder's avatar MechCoder Committed by Xiangrui Meng
Browse files

[SPARK-8711] [ML] Add additional methods to PySpark ML tree models

Add numNodes and depth to treeModels, add treeWeights to ensemble Models.
Add __repr__ to all models.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #7095 from MechCoder/missing_methods_tree and squashes the following commits:

23b08be [MechCoder] private [spark]
38a0860 [MechCoder] rename pyTreeWeights to javaTreeWeights
6d16ad8 [MechCoder] Fix Python 3 Error
47d7023 [MechCoder] Use np.allclose and treeEnsembleModel -> TreeEnsembleMethods
819098c [MechCoder] [SPARK-8711] [ML] Add additional methods to PySpark ML tree models
parent 0a63d7ab
No related branches found
No related tags found
No related merge requests found
......@@ -17,6 +17,7 @@
package org.apache.spark.ml.tree
import org.apache.spark.mllib.linalg.{Vectors, Vector}
/**
* Abstraction for Decision Tree models.
......@@ -70,6 +71,10 @@ private[ml] trait TreeEnsembleModel {
/** Weights for each tree, zippable with [[trees]] */
def treeWeights: Array[Double]
/**
 * Weights for each tree, wrapped in a dense [[Vector]] for use by the Python wrappers.
 *
 * Note: An array cannot be returned directly due to serialization problems.
 */
private[spark] def javaTreeWeights: Vector = Vectors.dense(treeWeights)
/** Summary of the model */
override def toString: String = {
// Implementing classes should generally override this method to be more descriptive.
......
......@@ -18,7 +18,8 @@
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
from pyspark.ml.regression import RandomForestParams
from pyspark.ml.regression import (
RandomForestParams, DecisionTreeModel, TreeEnsembleModels)
from pyspark.mllib.common import inherit_doc
......@@ -202,6 +203,10 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> td = si_model.transform(df)
>>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed")
>>> model = dt.fit(td)
>>> model.numNodes
3
>>> model.depth
1
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
......@@ -269,7 +274,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
return self.getOrDefault(self.impurity)
class DecisionTreeClassificationModel(JavaModel):
@inherit_doc
class DecisionTreeClassificationModel(DecisionTreeModel):
"""
Model fitted by DecisionTreeClassifier.
"""
......@@ -284,6 +290,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
It supports both binary and multiclass labels, as well as both continuous and categorical
features.
>>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> df = sqlContext.createDataFrame([
......@@ -294,6 +301,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> td = si_model.transform(df)
>>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42)
>>> model = rf.fit(td)
>>> allclose(model.treeWeights, [1.0, 1.0])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
......@@ -423,7 +432,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
return self.getOrDefault(self.featureSubsetStrategy)
class RandomForestClassificationModel(JavaModel):
class RandomForestClassificationModel(TreeEnsembleModels):
"""
Model fitted by RandomForestClassifier.
"""
......@@ -438,6 +447,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
It supports binary labels, as well as both continuous and categorical features.
Note: Multiclass labels are not currently supported.
>>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> df = sqlContext.createDataFrame([
......@@ -448,6 +458,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
>>> td = si_model.transform(df)
>>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed")
>>> model = gbt.fit(td)
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
......@@ -558,7 +570,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
return self.getOrDefault(self.stepSize)
class GBTClassificationModel(JavaModel):
class GBTClassificationModel(TreeEnsembleModels):
"""
Model fitted by GBTClassifier.
"""
......
......@@ -172,6 +172,10 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> dt = DecisionTreeRegressor(maxDepth=2)
>>> model = dt.fit(df)
>>> model.depth
1
>>> model.numNodes
3
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
......@@ -239,7 +243,37 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
return self.getOrDefault(self.impurity)
class DecisionTreeRegressionModel(JavaModel):
@inherit_doc
class DecisionTreeModel(JavaModel):
    """
    Abstraction for Decision Tree models.

    Every accessor forwards to the method of the same name on the
    wrapped Java model via ``_call_java``.
    """

    @property
    def numNodes(self):
        """Return number of nodes of the decision tree."""
        node_count = self._call_java("numNodes")
        return node_count

    @property
    def depth(self):
        """Return depth of the decision tree."""
        tree_depth = self._call_java("depth")
        return tree_depth

    def __repr__(self):
        # Reuse the Java-side summary string as the Python representation.
        summary = self._call_java("toString")
        return summary
@inherit_doc
class TreeEnsembleModels(JavaModel):
    """
    Abstraction for tree ensemble models.

    Exposes the per-tree weights of the wrapped Java model and uses its
    ``toString`` output as the Python representation.
    """

    @property
    def treeWeights(self):
        """Return the weights for each tree"""
        # The Java side exposes the weights through ``javaTreeWeights``;
        # materialize the result as a plain Python list.
        java_weights = self._call_java("javaTreeWeights")
        return list(java_weights)

    def __repr__(self):
        # Reuse the Java-side summary string as the Python representation.
        summary = self._call_java("toString")
        return summary
@inherit_doc
class DecisionTreeRegressionModel(DecisionTreeModel):
"""
Model fitted by DecisionTreeRegressor.
"""
......@@ -253,12 +287,15 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
learning algorithm for regression.
It supports both continuous and categorical features.
>>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> df = sqlContext.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42)
>>> model = rf.fit(df)
>>> allclose(model.treeWeights, [1.0, 1.0])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
......@@ -389,7 +426,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
return self.getOrDefault(self.featureSubsetStrategy)
class RandomForestRegressionModel(JavaModel):
class RandomForestRegressionModel(TreeEnsembleModels):
"""
Model fitted by RandomForestRegressor.
"""
......@@ -403,12 +440,15 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
learning algorithm for regression.
It supports both continuous and categorical features.
>>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> df = sqlContext.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> gbt = GBTRegressor(maxIter=5, maxDepth=2)
>>> model = gbt.fit(df)
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
......@@ -518,7 +558,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
return self.getOrDefault(self.stepSize)
class GBTRegressionModel(JavaModel):
class GBTRegressionModel(TreeEnsembleModels):
"""
Model fitted by GBTRegressor.
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment