Skip to content
Snippets Groups Projects
Commit 8fa8c837 authored by Yu ISHIKAWA's avatar Yu ISHIKAWA Committed by Joseph K. Bradley
Browse files

[SPARK-11514][ML] Pass random seed to spark.ml DecisionTree*

cc jkbradley

Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>

Closes #9486 from yu-iskw/SPARK-11514.
parent 6091e91f
No related branches found
No related tags found
No related merge requests found
......@@ -62,6 +62,8 @@ final class DecisionTreeClassifier(override val uid: String)
override def setImpurity(value: String): this.type = super.setImpurity(value)
override def setSeed(value: Long): this.type = super.setSeed(value)
override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
......@@ -75,7 +77,7 @@ final class DecisionTreeClassifier(override val uid: String)
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = 0L, parentUID = Some(uid))
seed = $(seed), parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeClassificationModel]
}
......
......@@ -71,13 +71,15 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
@Since("1.4.0")
override def setImpurity(value: String): this.type = super.setImpurity(value)
override def setSeed(value: Long): this.type = super.setSeed(value)
override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy = getOldStrategy(categoricalFeatures)
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = 0L, parentUID = Some(uid))
seed = $(seed), parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeRegressionModel]
}
......
......@@ -29,7 +29,8 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
*
* Note: Marked as private and DeveloperApi since this may be made public in the future.
*/
private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointInterval {
private[ml] trait DecisionTreeParams extends PredictorParams
with HasCheckpointInterval with HasSeed {
/**
* Maximum depth of the tree (>= 0).
......@@ -123,6 +124,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointI
/** @group getParam */
final def getMinInfoGain: Double = $(minInfoGain)
/** @group setParam */
def setSeed(value: Long): this.type = set(seed, value)
/** @group expertSetParam */
def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
......@@ -257,7 +261,7 @@ private[ml] object TreeRegressorParams {
*
* Note: Marked as private and DeveloperApi since this may be made public in the future.
*/
private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
/**
* Fraction of the training data used for learning each decision tree, in range (0, 1].
......@@ -276,9 +280,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
/** @group getParam */
final def getSubsamplingRate: Double = $(subsamplingRate)
/** @group setParam */
def setSeed(value: Long): this.type = set(seed, value)
/**
* Create a Strategy instance to use with the old API.
* NOTE: The caller should set impurity and seed.
......
......@@ -72,6 +72,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
.setImpurity("gini")
.setMaxDepth(2)
.setMaxBins(100)
.setSeed(1)
val categoricalFeatures = Map(0 -> 3, 1-> 3)
val numClasses = 2
compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses)
......
......@@ -49,6 +49,7 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
.setImpurity("variance")
.setMaxDepth(2)
.setMaxBins(100)
.setSeed(1)
val categoricalFeatures = Map(0 -> 3, 1-> 3)
compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures)
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment