Commit 56f2c61c authored by Sung Chung, committed by Xiangrui Meng

[SPARK-3161][MLLIB] Adding a node Id caching mechanism for training decision trees.

jkbradley mengxr chouqin Please review this.

Author: Sung Chung <schung@alpinenow.com>

Closes #2868 from codedeft/SPARK-3161 and squashes the following commits:

5f5a156 [Sung Chung] [SPARK-3161][MLLIB] Adding a node Id caching mechanism for training decision trees.
parent d8176b1c
@@ -62,7 +62,10 @@ object DecisionTreeRunner {
minInfoGain: Double = 0.0,
numTrees: Int = 1,
featureSubsetStrategy: String = "auto",
fracTest: Double = 0.2) extends AbstractParams[Params]
fracTest: Double = 0.2,
useNodeIdCache: Boolean = false,
checkpointDir: Option[String] = None,
checkpointInterval: Int = 10) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
@@ -102,6 +105,21 @@ object DecisionTreeRunner {
.text(s"fraction of data to hold out for testing. If given option testInput, " +
s"this option is ignored. default: ${defaultParams.fracTest}")
.action((x, c) => c.copy(fracTest = x))
opt[Boolean]("useNodeIdCache")
.text(s"whether to use node Id cache during training, " +
s"default: ${defaultParams.useNodeIdCache}")
.action((x, c) => c.copy(useNodeIdCache = x))
opt[String]("checkpointDir")
.text(s"checkpoint directory where intermediate node Id caches will be stored, " +
s"default: ${defaultParams.checkpointDir match {
case Some(strVal) => strVal
case None => "None"
}}")
.action((x, c) => c.copy(checkpointDir = Some(x)))
opt[Int]("checkpointInterval")
.text(s"how often to checkpoint the node Id cache, " +
s"default: ${defaultParams.checkpointInterval}")
.action((x, c) => c.copy(checkpointInterval = x))
opt[String]("testInput")
.text(s"input path to test dataset. If given, option fracTest is ignored." +
s" default: ${defaultParams.testInput}")
@@ -236,7 +254,10 @@ object DecisionTreeRunner {
maxBins = params.maxBins,
numClassesForClassification = numClasses,
minInstancesPerNode = params.minInstancesPerNode,
minInfoGain = params.minInfoGain)
minInfoGain = params.minInfoGain,
useNodeIdCache = params.useNodeIdCache,
checkpointDir = params.checkpointDir,
checkpointInterval = params.checkpointInterval)
if (params.numTrees == 1) {
val startTime = System.nanoTime()
val model = DecisionTree.train(training, strategy)
...
@@ -437,6 +437,11 @@ object DecisionTree extends Serializable with Logging {
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
* Updated with new non-leaf nodes which are created.
* @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
* each value in the array is the data point's node Id
* for a corresponding tree. This is used to avoid having to
* pass the entire tree to the executors during
* the node stat aggregation phase.
*/
private[tree] def findBestSplits(
input: RDD[BaggedPoint[TreePoint]],
@@ -447,7 +452,8 @@ object DecisionTree extends Serializable with Logging {
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
nodeQueue: mutable.Queue[(Int, Node)],
timer: TimeTracker = new TimeTracker): Unit = {
timer: TimeTracker = new TimeTracker,
nodeIdCache: Option[NodeIdCache] = None): Unit = {
/*
* The high-level descriptions of the best split optimizations are noted here.
@@ -479,6 +485,37 @@ object DecisionTree extends Serializable with Logging {
logDebug("isMulticlass = " + metadata.isMulticlass)
logDebug("isMulticlassWithCategoricalFeatures = " +
metadata.isMulticlassWithCategoricalFeatures)
logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
/**
* Performs a sequential aggregation over a partition for a particular tree and node.
*
* For each feature, the aggregate sufficient statistics are updated for the relevant
* bins.
*
* @param treeIndex Index of the tree that we want to perform aggregation for.
* @param nodeInfo The node info for the tree node.
* @param agg Array storing aggregate calculation, with a set of sufficient statistics
* for each (node, feature, bin).
* @param baggedPoint Data point being aggregated.
*/
def nodeBinSeqOp(
treeIndex: Int,
nodeInfo: RandomForest.NodeIndexInfo,
agg: Array[DTStatsAggregator],
baggedPoint: BaggedPoint[TreePoint]): Unit = {
if (nodeInfo != null) {
val aggNodeIndex = nodeInfo.nodeIndexInGroup
val featuresForNode = nodeInfo.featureSubset
val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
if (metadata.unorderedFeatures.isEmpty) {
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
} else {
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
instanceWeight, featuresForNode)
}
}
}
/**
* Performs a sequential aggregation over a partition.
@@ -497,20 +534,25 @@ object DecisionTree extends Serializable with Logging {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
bins, metadata.unorderedFeatures)
val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null)
// If the example does not reach a node in this group, then nodeIndex = null.
if (nodeInfo != null) {
val aggNodeIndex = nodeInfo.nodeIndexInGroup
val featuresForNode = nodeInfo.featureSubset
val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
if (metadata.unorderedFeatures.isEmpty) {
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
} else {
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
instanceWeight, featuresForNode)
}
}
nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
}
agg
}
/**
* Does the same thing as binSeqOp, but uses the node Id cache to look up each point's node index.
*/
def binSeqOpWithNodeIdCache(
agg: Array[DTStatsAggregator],
dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val baggedPoint = dataPoint._1
val nodeIdCache = dataPoint._2
val nodeIndex = nodeIdCache(treeIndex)
nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
}
agg
}
@@ -553,7 +595,26 @@ object DecisionTree extends Serializable with Logging {
// Finally, only best Splits for nodes are collected to driver to construct decision tree.
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
val nodeToBestSplits =
val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
// Construct a nodeStatsAggregators array to hold node aggregate stats,
// each node will have a nodeStatsAggregator
val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
Some(nodeToFeatures(nodeIndex))
}
new DTStatsAggregator(metadata, featuresForNode)
}
// iterate over all instances in the current partition and update aggregate stats
points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
// transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with other partitions using `reduceByKey`
nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
}
} else {
input.mapPartitions { points =>
// Construct a nodeStatsAggregators array to hold node aggregate stats,
// each node will have a nodeStatsAggregator
@@ -570,7 +631,10 @@ object DecisionTree extends Serializable with Logging {
// transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with other partitions using `reduceByKey`
nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
}.reduceByKey((a, b) => a.merge(b))
}
}
val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
.map { case (nodeIndex, aggStats) =>
val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
Some(nodeToFeatures(nodeIndex))
@@ -584,6 +648,13 @@ object DecisionTree extends Serializable with Logging {
timer.stop("chooseSplits")
val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
Array.fill[mutable.Map[Int, NodeIndexUpdater]](
metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]())
} else {
null
}
// Iterate over all nodes in this group.
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
nodesForTree.foreach { node =>
@@ -613,6 +684,13 @@ object DecisionTree extends Serializable with Logging {
node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
if (nodeIdCache.nonEmpty) {
val nodeIndexUpdater = NodeIndexUpdater(
split = split,
nodeIndex = nodeIndex)
nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
}
// enqueue left child and right child if they are not leaves
if (!leftChildIsLeaf) {
nodeQueue.enqueue((treeIndex, node.leftNode.get))
@@ -629,6 +707,10 @@ object DecisionTree extends Serializable with Logging {
}
}
if (nodeIdCache.nonEmpty) {
// Update the cache if needed.
nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, bins)
}
}
/**
...
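For reviewers, a minimal self-contained sketch of the cached aggregation path introduced above: rather than calling predictNodeIndex to walk each tree from its root for every data point, the point RDD is zipped with the cached node Ids, per-node statistics are built per partition, and partitions are then merged with reduceByKey. Plain Doubles stand in for BaggedPoint/TreePoint and DTStatsAggregator, the single-tree case is assumed, and the node Ids are treated as already group-local, so this is illustrative only; it assumes an existing SparkContext.

import org.apache.spark.SparkContext._ // pair-RDD operations such as reduceByKey (Spark 1.x)
import org.apache.spark.rdd.RDD

// Illustrative stand-in: sum one Double per node instead of DTStatsAggregator statistics.
def aggregateWithNodeIdCache(
    points: RDD[Double], // stand-in for RDD[BaggedPoint[TreePoint]]
    nodeIds: RDD[Array[Int]], // cached node Id per tree for each point
    numNodes: Int): Map[Int, Double] = {
  points.zip(nodeIds).mapPartitions { iter =>
    // One accumulator per node in the current group, allocated once per partition.
    val aggs = Array.fill(numNodes)(0.0)
    iter.foreach { case (x, ids) =>
      // The cached Id replaces a root-to-leaf traversal of the tree for this point.
      aggs(ids(0)) += x
    }
    // Emit (nodeIndex, partialStats) pairs so partitions can be merged with reduceByKey.
    aggs.view.zipWithIndex.map(_.swap).iterator
  }.reduceByKey(_ + _).collect().toMap
}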
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker}
import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache}
import org.apache.spark.mllib.tree.impurity.Impurities
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
@@ -160,6 +160,19 @@ private class RandomForest (
* in lower levels).
*/
// Create an RDD of node Id cache.
// At first, all the rows belong to the root nodes (node Id == 1).
val nodeIdCache = if (strategy.useNodeIdCache) {
Some(NodeIdCache.init(
data = baggedInput,
numTrees = numTrees,
checkpointDir = strategy.checkpointDir,
checkpointInterval = strategy.checkpointInterval,
initVal = 1))
} else {
None
}
// FIFO queue of nodes to train: (treeIndex, node)
val nodeQueue = new mutable.Queue[(Int, Node)]()
@@ -182,7 +195,7 @@ private class RandomForest (
// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
timer.stop("findBestSplits")
}
@@ -193,6 +206,11 @@ private class RandomForest (
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
// Delete any remaining checkpoints used for node Id cache.
if (nodeIdCache.nonEmpty) {
nodeIdCache.get.deleteAllCheckpoints()
}
val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
val treeWeights = Array.fill[Double](numTrees)(1.0)
new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average)
...
@@ -60,6 +60,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
* 256 MB.
* @param subsamplingRate Fraction of the training data used for learning decision tree.
* @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
* maintain a separate RDD that caches each row's current node Id for every tree.
* @param checkpointDir If the node Id cache is used, it helps to checkpoint
* the node Id cache periodically. This is the checkpoint directory
* to be used for the node Id cache.
* @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
* E.g. 10 means that the cache will get checkpointed every 10 updates.
*/
@Experimental
class Strategy (
@@ -73,7 +80,10 @@ class Strategy (
@BeanProperty var minInstancesPerNode: Int = 1,
@BeanProperty var minInfoGain: Double = 0.0,
@BeanProperty var maxMemoryInMB: Int = 256,
@BeanProperty var subsamplingRate: Double = 1) extends Serializable {
@BeanProperty var subsamplingRate: Double = 1,
@BeanProperty var useNodeIdCache: Boolean = false,
@BeanProperty var checkpointDir: Option[String] = None,
@BeanProperty var checkpointInterval: Int = 10) extends Serializable {
if (algo == Classification) {
require(numClassesForClassification >= 2)
...
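From a user's perspective, the three new Strategy fields are all that is needed to turn the cache on. A hedged usage sketch, mirroring the test suite further down (the checkpoint directory path is hypothetical, and rdd is assumed to be an existing RDD[LabeledPoint]):

import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

val strategy = new Strategy(
  algo = Classification,
  impurity = Gini,
  maxDepth = 5,
  numClassesForClassification = 2,
  useNodeIdCache = true, // keep per-row node Ids in an RDD instead of shipping trees
  checkpointDir = Some("/tmp/node-id-cache"), // hypothetical path, used only by the cache
  checkpointInterval = 10) // checkpoint the node Id cache every 10 updates

// The training call is unchanged; the new flags only affect the internals.
val model = RandomForest.trainClassifier(
  rdd, strategy, numTrees = 10, featureSubsetStrategy = "auto", seed = 42)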
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.impl
import scala.collection.mutable
import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.spark.rdd.RDD
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.mllib.tree.model.{Bin, Node, Split}
/**
* :: DeveloperApi ::
* This is used by the node id cache to find the child id that a data point would belong to.
* @param split Split information.
* @param nodeIndex The current node index of a data point that this will update.
*/
@DeveloperApi
private[tree] case class NodeIndexUpdater(
split: Split,
nodeIndex: Int) {
/**
* Determine a child node index based on the feature value and the split.
* @param binnedFeatures Binned feature values.
* @param bins Bin information to convert the bin indices to approximate feature values.
* @return Child node index to update to.
*/
def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = {
if (split.featureType == Continuous) {
val featureIndex = split.feature
val binIndex = binnedFeatures(featureIndex)
val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
if (featureValueUpperBound <= split.threshold) {
Node.leftChildIndex(nodeIndex)
} else {
Node.rightChildIndex(nodeIndex)
}
} else {
if (split.categories.contains(binnedFeatures(split.feature).toDouble)) {
Node.leftChildIndex(nodeIndex)
} else {
Node.rightChildIndex(nodeIndex)
}
}
}
}
/**
* :: DeveloperApi ::
* A given TreePoint would belong to a particular node per tree.
* Each row in the nodeIdsForInstances RDD is an array over trees of the node index
* in each tree. Initially, all values should be 1 (the root node).
* The nodeIdsForInstances RDD needs to be updated at each iteration.
* @param nodeIdsForInstances The initial values in the cache
* (should be an Array of all 1's, i.e. every instance starts at the root node).
* @param checkpointDir The checkpoint directory where
* the checkpointed files will be stored.
* @param checkpointInterval The checkpointing interval
* (how often the cache should be checkpointed).
*/
@DeveloperApi
private[tree] class NodeIdCache(
var nodeIdsForInstances: RDD[Array[Int]],
val checkpointDir: Option[String],
val checkpointInterval: Int) {
// Keep a reference to the previous node Ids for instances.
// Because we will keep on re-persisting updated node Ids,
// we want to unpersist the previous RDD.
private var prevNodeIdsForInstances: RDD[Array[Int]] = null
// To keep track of the past checkpointed RDDs.
private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
private var rddUpdateCount = 0
// If a checkpoint directory is given, and there's no prior checkpoint directory,
// then set the checkpoint directory with the given one.
if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) {
nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get)
}
/**
* Update the node index values in the cache.
* This updates the RDD and its lineage.
* TODO: Passing bin information to executors seems unnecessary and costly.
* @param data The RDD of training rows.
* @param nodeIdUpdaters A map of node index updaters.
* The keys are the indices of the nodes that we want to update.
* @param bins Bin information needed to find child node indices.
*/
def updateNodeIndices(
data: RDD[BaggedPoint[TreePoint]],
nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
bins: Array[Array[Bin]]): Unit = {
if (prevNodeIdsForInstances != null) {
// Unpersist the previous one if one exists.
prevNodeIdsForInstances.unpersist()
}
prevNodeIdsForInstances = nodeIdsForInstances
nodeIdsForInstances = data.zip(nodeIdsForInstances).map {
dataPoint => {
var treeId = 0
while (treeId < nodeIdUpdaters.length) {
val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(dataPoint._2(treeId), null)
if (nodeIdUpdater != null) {
val newNodeIndex = nodeIdUpdater.updateNodeIndex(
binnedFeatures = dataPoint._1.datum.binnedFeatures,
bins = bins)
dataPoint._2(treeId) = newNodeIndex
}
treeId += 1
}
dataPoint._2
}
}
// Keep on persisting new ones.
nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
rddUpdateCount += 1
// Handle checkpointing if the directory is not None.
if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty &&
(rddUpdateCount % checkpointInterval) == 0) {
// Let's see if we can delete previous checkpoints.
var canDelete = true
while (checkpointQueue.size > 1 && canDelete) {
// We can delete the oldest checkpoint iff
// the next checkpoint actually exists in the file system.
if (checkpointQueue.get(1).get.getCheckpointFile != None) {
val old = checkpointQueue.dequeue()
// Since the old checkpoint is not deleted by Spark,
// we'll manually delete it here.
val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
fs.delete(new Path(old.getCheckpointFile.get), true)
} else {
canDelete = false
}
}
nodeIdsForInstances.checkpoint()
checkpointQueue.enqueue(nodeIdsForInstances)
}
}
/**
* Call this after training is finished to delete any remaining checkpoints.
*/
def deleteAllCheckpoints(): Unit = {
while (checkpointQueue.size > 0) {
val old = checkpointQueue.dequeue()
if (old.getCheckpointFile != None) {
val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
fs.delete(new Path(old.getCheckpointFile.get), true)
}
}
}
}
@DeveloperApi
private[tree] object NodeIdCache {
/**
* Initialize the node Id cache with initial node Id values.
* @param data The RDD of training rows.
* @param numTrees The number of trees that we want to create the cache for.
* @param checkpointDir The checkpoint directory where the checkpointed files will be stored.
* @param checkpointInterval The checkpointing interval
* (how often the cache should be checkpointed).
* @param initVal The initial values in the cache.
* @return A node Id cache containing an RDD of initial root node indices.
*/
def init(
data: RDD[BaggedPoint[TreePoint]],
numTrees: Int,
checkpointDir: Option[String],
checkpointInterval: Int,
initVal: Int = 1): NodeIdCache = {
new NodeIdCache(
data.map(_ => Array.fill[Int](numTrees)(initVal)),
checkpointDir,
checkpointInterval)
}
}
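To tie the new NodeIdCache file back to the RandomForest changes above, a sketch of the cache's lifecycle during training. These are private[tree] internals, so this is illustrative rather than user-facing API; baggedInput, numTrees, nodeIdUpdaters, and bins are the values RandomForest and DecisionTree already have in scope.

// 1. Before the first iteration: every instance starts at the root of every tree.
val nodeIdCache = NodeIdCache.init(
  data = baggedInput, // RDD[BaggedPoint[TreePoint]]
  numTrees = numTrees,
  checkpointDir = strategy.checkpointDir,
  checkpointInterval = strategy.checkpointInterval,
  initVal = 1) // node Id 1 is the root

// 2. After each group of nodes is split, one NodeIndexUpdater per split node pushes
//    every instance's cached node Id one level down its tree (and periodically checkpoints).
nodeIdCache.updateNodeIndices(baggedInput, nodeIdUpdaters, bins)

// 3. Once all trees are grown, delete any checkpoint files left behind.
nodeIdCache.deleteAllCheckpoints()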
@@ -34,18 +34,11 @@ import org.apache.spark.mllib.util.LocalSparkContext
* Test suite for [[RandomForest]].
*/
class RandomForestSuite extends FunSuite with LocalSparkContext {
test("Binary classification with continuous features:" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
val categoricalFeaturesInfo = Map.empty[Int, Int]
val numTrees = 1
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)
assert(rf.weakHypotheses.size === 1)
@@ -60,18 +53,27 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
assert(rfTree.toString == dt.toString)
}
test("Regression with continuous features:" +
test("Binary classification with continuous features:" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
binaryClassificationTestWithContinuousFeatures(strategy)
}
test("Binary classification with continuous features and node Id cache :" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
binaryClassificationTestWithContinuousFeatures(strategy)
}
def regressionTestWithContinuousFeatures(strategy: Strategy) {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
val categoricalFeaturesInfo = Map.empty[Int, Int]
val numTrees = 1
val strategy = new Strategy(algo = Regression, impurity = Variance,
maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
categoricalFeaturesInfo = categoricalFeaturesInfo)
val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)
assert(rf.weakHypotheses.size === 1)
@@ -86,14 +88,28 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
assert(rfTree.toString == dt.toString)
}
test("Binary classification with continuous features: subsampling features") {
test("Regression with continuous features:" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val strategy = new Strategy(algo = Regression, impurity = Variance,
maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
categoricalFeaturesInfo = categoricalFeaturesInfo)
regressionTestWithContinuousFeatures(strategy)
}
test("Regression with continuous features and node Id cache :" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val strategy = new Strategy(algo = Regression, impurity = Variance,
maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
regressionTestWithContinuousFeatures(strategy)
}
def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: Strategy) {
val numFeatures = 50
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
val rdd = sc.parallelize(arr)
val categoricalFeaturesInfo = Map.empty[Int, Int]
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
// Select feature subset for top nodes. Return true if OK.
def checkFeatureSubsetStrategy(
@@ -149,6 +165,20 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
}
test("Binary classification with continuous features: subsampling features") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
}
test("Binary classification with continuous features and node Id cache: subsampling features") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
}
test("alternating categorical and continuous features with multiclass labels to test indexing") {
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0))
@@ -164,7 +194,6 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
featureSubsetStrategy = "sqrt", seed = 12345)
EnsembleTestHelper.validateClassifier(model, arr, 1.0)
}
}