From e963070c13f56fbc2dfaf9f5d4e69d34afd0957c Mon Sep 17 00:00:00 2001
From: Yu ISHIKAWA <yuu.ishikawa@gmail.com>
Date: Sun, 1 Nov 2015 23:52:50 -0800
Subject: [PATCH] [SPARK-9722] [ML] Pass random seed to spark.ml DecisionTree*

Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>

Closes #9402 from yu-iskw/SPARK-9722.
---
 .../org/apache/spark/ml/tree/impl/RandomForest.scala      | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 96d5652857..4a3b12d144 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -74,7 +74,7 @@ private[ml] object RandomForest extends Logging {
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
     timer.start("findSplitsBins")
-    val splits = findSplits(retaggedInput, metadata)
+    val splits = findSplits(retaggedInput, metadata, seed)
     timer.stop("findSplitsBins")
     logDebug("numBins: feature: number of bins")
     logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
@@ -815,6 +815,7 @@ private[ml] object RandomForest extends Logging {
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
    * @param metadata Learning and dataset metadata
+   * @param seed random seed
    * @return A tuple of (splits, bins).
    *         Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
    *          of size (numFeatures, numSplits).
@@ -823,7 +824,8 @@ private[ml] object RandomForest extends Logging {
    */
   protected[tree] def findSplits(
       input: RDD[LabeledPoint],
-      metadata: DecisionTreeMetadata): Array[Array[Split]] = {
+      metadata: DecisionTreeMetadata,
+      seed : Long): Array[Array[Split]] = {
 
     logDebug("isMulticlass = " + metadata.isMulticlass)
 
@@ -840,7 +842,7 @@ private[ml] object RandomForest extends Logging {
         1.0
       }
       logDebug("fraction of data used for calculating quantiles = " + fraction)
-      input.sample(withReplacement = false, fraction, new XORShiftRandom(1).nextInt()).collect()
+      input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
     } else {
       new Array[LabeledPoint](0)
     }
-- 
GitLab