From 3aa3488225af12a77da3ba807906bc6a461ef11c Mon Sep 17 00:00:00 2001
From: Yanbo Liang <ybliang8@gmail.com>
Date: Wed, 6 Jan 2016 10:52:25 -0800
Subject: [PATCH] [SPARK-11815][ML][PYSPARK] PySpark DecisionTreeClassifier &
 DecisionTreeRegressor should support setSeed

PySpark ```DecisionTreeClassifier``` & ```DecisionTreeRegressor``` should support ```setSeed``` like what we do at Scala side.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #9807 from yanboliang/spark-11815.
---
 python/pyspark/ml/classification.py | 13 ++++++++-----
 python/pyspark/ml/regression.py     | 14 +++++++++-----
 2 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 5599b8f3ec..265c6a14f1 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -273,7 +273,7 @@ class GBTParams(TreeEnsembleParams):
 @inherit_doc
 class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                              HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams,
-                             TreeClassifierParams, HasCheckpointInterval):
+                             TreeClassifierParams, HasCheckpointInterval, HasSeed):
     """
     `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
     learning algorithm for classification.
@@ -313,12 +313,14 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  probabilityCol="probability", rawPredictionCol="rawPrediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
-                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini"):
+                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
+                 seed=None):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  probabilityCol="probability", rawPredictionCol="rawPrediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
-                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
+                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
+                 seed=None)
         """
         super(DecisionTreeClassifier, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -335,12 +337,13 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
                   probabilityCol="probability", rawPredictionCol="rawPrediction",
                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                   maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
-                  impurity="gini"):
+                  impurity="gini", seed=None):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   probabilityCol="probability", rawPredictionCol="rawPrediction", \
                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
-                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
+                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
+                  seed=None)
         Sets params for the DecisionTreeClassifier.
         """
         kwargs = self.setParams._input_kwargs
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index a0bb8ceed8..401bac0223 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -386,7 +386,8 @@ class GBTParams(TreeEnsembleParams):
 
 @inherit_doc
 class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
-                            DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval):
+                            DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval,
+                            HasSeed):
     """
     `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
     learning algorithm for regression.
@@ -415,11 +416,13 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
-                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance"):
+                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance",
+                 seed=None):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
-                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance")
+                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
+                 impurity="variance", seed=None)
         """
         super(DecisionTreeRegressor, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -435,11 +438,12 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                   maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
-                  impurity="variance"):
+                  impurity="variance", seed=None):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
-                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance")
+                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
+                  impurity="variance", seed=None)
         Sets params for the DecisionTreeRegressor.
         """
         kwargs = self.setParams._input_kwargs
-- 
GitLab