Skip to content
Snippets Groups Projects
Commit 39b44cb5 authored by Yu ISHIKAWA, committed by Xiangrui Meng
Browse files

[SPARK-10278] [MLLIB] [PYSPARK] Add @since annotation to pyspark.mllib.tree

Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>

Closes #8685 from yu-iskw/SPARK-10278.
parent 0ded87a4
No related branches found
No related tags found
No related merge requests found
...@@ -19,7 +19,7 @@ from __future__ import absolute_import ...@@ -19,7 +19,7 @@ from __future__ import absolute_import
import random import random
from pyspark import SparkContext, RDD from pyspark import SparkContext, RDD, since
from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper
from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.linalg import _convert_to_vector
from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.regression import LabeledPoint
...@@ -30,6 +30,11 @@ __all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel', ...@@ -30,6 +30,11 @@ __all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel',
class TreeEnsembleModel(JavaModelWrapper, JavaSaveable): class TreeEnsembleModel(JavaModelWrapper, JavaSaveable):
"""TreeEnsembleModel
.. versionadded:: 1.3.0
"""
@since("1.3.0")
def predict(self, x): def predict(self, x):
""" """
Predict values for a single data point or an RDD of points using Predict values for a single data point or an RDD of points using
...@@ -45,12 +50,14 @@ class TreeEnsembleModel(JavaModelWrapper, JavaSaveable): ...@@ -45,12 +50,14 @@ class TreeEnsembleModel(JavaModelWrapper, JavaSaveable):
else: else:
return self.call("predict", _convert_to_vector(x)) return self.call("predict", _convert_to_vector(x))
@since("1.3.0")
def numTrees(self): def numTrees(self):
""" """
Get number of trees in ensemble. Get number of trees in ensemble.
""" """
return self.call("numTrees") return self.call("numTrees")
@since("1.3.0")
def totalNumNodes(self): def totalNumNodes(self):
""" """
Get total number of nodes, summed over all trees in the Get total number of nodes, summed over all trees in the
...@@ -62,6 +69,7 @@ class TreeEnsembleModel(JavaModelWrapper, JavaSaveable): ...@@ -62,6 +69,7 @@ class TreeEnsembleModel(JavaModelWrapper, JavaSaveable):
""" Summary of model """ """ Summary of model """
return self._java_model.toString() return self._java_model.toString()
@since("1.3.0")
def toDebugString(self): def toDebugString(self):
""" Full model """ """ Full model """
return self._java_model.toDebugString() return self._java_model.toDebugString()
...@@ -72,7 +80,10 @@ class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader): ...@@ -72,7 +80,10 @@ class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader):
.. note:: Experimental .. note:: Experimental
A decision tree model for classification or regression. A decision tree model for classification or regression.
.. versionadded:: 1.1.0
""" """
@since("1.1.0")
def predict(self, x): def predict(self, x):
""" """
Predict the label of one or more examples. Predict the label of one or more examples.
...@@ -90,16 +101,23 @@ class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader): ...@@ -90,16 +101,23 @@ class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader):
else: else:
return self.call("predict", _convert_to_vector(x)) return self.call("predict", _convert_to_vector(x))
@since("1.1.0")
def numNodes(self): def numNodes(self):
"""Get number of nodes in tree, including leaf nodes."""
return self._java_model.numNodes() return self._java_model.numNodes()
@since("1.1.0")
def depth(self): def depth(self):
"""Get depth of tree.
E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes.
"""
return self._java_model.depth() return self._java_model.depth()
def __repr__(self): def __repr__(self):
""" summary of model. """ """ summary of model. """
return self._java_model.toString() return self._java_model.toString()
@since("1.2.0")
def toDebugString(self): def toDebugString(self):
""" full model. """ """ full model. """
return self._java_model.toDebugString() return self._java_model.toDebugString()
...@@ -115,6 +133,8 @@ class DecisionTree(object): ...@@ -115,6 +133,8 @@ class DecisionTree(object):
Learning algorithm for a decision tree model for classification or Learning algorithm for a decision tree model for classification or
regression. regression.
.. versionadded:: 1.1.0
""" """
@classmethod @classmethod
...@@ -127,6 +147,7 @@ class DecisionTree(object): ...@@ -127,6 +147,7 @@ class DecisionTree(object):
return DecisionTreeModel(model) return DecisionTreeModel(model)
@classmethod @classmethod
@since("1.1.0")
def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo,
impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0): minInfoGain=0.0):
...@@ -185,6 +206,7 @@ class DecisionTree(object): ...@@ -185,6 +206,7 @@ class DecisionTree(object):
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
@classmethod @classmethod
@since("1.1.0")
def trainRegressor(cls, data, categoricalFeaturesInfo, def trainRegressor(cls, data, categoricalFeaturesInfo,
impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1, impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0): minInfoGain=0.0):
...@@ -239,6 +261,8 @@ class RandomForestModel(TreeEnsembleModel, JavaLoader): ...@@ -239,6 +261,8 @@ class RandomForestModel(TreeEnsembleModel, JavaLoader):
.. note:: Experimental .. note:: Experimental
Represents a random forest model. Represents a random forest model.
.. versionadded:: 1.2.0
""" """
@classmethod @classmethod
...@@ -252,6 +276,8 @@ class RandomForest(object): ...@@ -252,6 +276,8 @@ class RandomForest(object):
Learning algorithm for a random forest model for classification or Learning algorithm for a random forest model for classification or
regression. regression.
.. versionadded:: 1.2.0
""" """
supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird") supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird")
...@@ -271,6 +297,7 @@ class RandomForest(object): ...@@ -271,6 +297,7 @@ class RandomForest(object):
return RandomForestModel(model) return RandomForestModel(model)
@classmethod @classmethod
@since("1.2.0")
def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees,
featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32, featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32,
seed=None): seed=None):
...@@ -352,6 +379,7 @@ class RandomForest(object): ...@@ -352,6 +379,7 @@ class RandomForest(object):
maxDepth, maxBins, seed) maxDepth, maxBins, seed)
@classmethod @classmethod
@since("1.2.0")
def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto",
impurity="variance", maxDepth=4, maxBins=32, seed=None): impurity="variance", maxDepth=4, maxBins=32, seed=None):
""" """
...@@ -418,6 +446,8 @@ class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader): ...@@ -418,6 +446,8 @@ class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader):
.. note:: Experimental .. note:: Experimental
Represents a gradient-boosted tree model. Represents a gradient-boosted tree model.
.. versionadded:: 1.3.0
""" """
@classmethod @classmethod
...@@ -431,6 +461,8 @@ class GradientBoostedTrees(object): ...@@ -431,6 +461,8 @@ class GradientBoostedTrees(object):
Learning algorithm for a gradient boosted trees model for Learning algorithm for a gradient boosted trees model for
classification or regression. classification or regression.
.. versionadded:: 1.3.0
""" """
@classmethod @classmethod
...@@ -443,6 +475,7 @@ class GradientBoostedTrees(object): ...@@ -443,6 +475,7 @@ class GradientBoostedTrees(object):
return GradientBoostedTreesModel(model) return GradientBoostedTreesModel(model)
@classmethod @classmethod
@since("1.3.0")
def trainClassifier(cls, data, categoricalFeaturesInfo, def trainClassifier(cls, data, categoricalFeaturesInfo,
loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3, loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3,
maxBins=32): maxBins=32):
...@@ -505,6 +538,7 @@ class GradientBoostedTrees(object): ...@@ -505,6 +538,7 @@ class GradientBoostedTrees(object):
loss, numIterations, learningRate, maxDepth, maxBins) loss, numIterations, learningRate, maxDepth, maxBins)
@classmethod @classmethod
@since("1.3.0")
def trainRegressor(cls, data, categoricalFeaturesInfo, def trainRegressor(cls, data, categoricalFeaturesInfo,
loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3, loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3,
maxBins=32): maxBins=32):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment