Skip to content
Snippets Groups Projects
Commit 3aa34882 authored by Yanbo Liang's avatar Yanbo Liang Committed by Joseph K. Bradley
Browse files

[SPARK-11815][ML][PYSPARK] PySpark DecisionTreeClassifier &...

[SPARK-11815][ML][PYSPARK] PySpark DecisionTreeClassifier & DecisionTreeRegressor should support setSeed

PySpark ```DecisionTreeClassifier``` & ```DecisionTreeRegressor``` should support ```setSeed``` like what we do at Scala side.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #9807 from yanboliang/spark-11815.
parent 95eb6516
No related branches found
No related tags found
No related merge requests found
......@@ -273,7 +273,7 @@ class GBTParams(TreeEnsembleParams):
@inherit_doc
class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams,
TreeClassifierParams, HasCheckpointInterval):
TreeClassifierParams, HasCheckpointInterval, HasSeed):
"""
`http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
learning algorithm for classification.
......@@ -313,12 +313,14 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
probabilityCol="probability", rawPredictionCol="rawPrediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini"):
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
seed=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
seed=None)
"""
super(DecisionTreeClassifier, self).__init__()
self._java_obj = self._new_java_obj(
......@@ -335,12 +337,13 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
probabilityCol="probability", rawPredictionCol="rawPrediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="gini"):
impurity="gini", seed=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
seed=None)
Sets params for the DecisionTreeClassifier.
"""
kwargs = self.setParams._input_kwargs
......
......@@ -386,7 +386,8 @@ class GBTParams(TreeEnsembleParams):
@inherit_doc
class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval):
DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval,
HasSeed):
"""
`http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
learning algorithm for regression.
......@@ -415,11 +416,13 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance"):
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance",
seed=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance")
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
impurity="variance", seed=None)
"""
super(DecisionTreeRegressor, self).__init__()
self._java_obj = self._new_java_obj(
......@@ -435,11 +438,12 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="variance"):
impurity="variance", seed=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance")
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
impurity="variance", seed=None)
Sets params for the DecisionTreeRegressor.
"""
kwargs = self.setParams._input_kwargs
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment