Skip to content
Snippets Groups Projects
Commit deb41133 authored by AiHe's avatar AiHe Committed by Joseph K. Bradley
Browse files

[SPARK-7473] [MLLIB] Add reservoir sample in RandomForest

reservoir feature sample by using existing api

Author: AiHe <ai.he@ussuning.com>

Closes #5988 from AiHe/reservoir and squashes the following commits:

e7a41ac [AiHe] remove non-robust testing case
28ffb9a [AiHe] set seed as rng.nextLong
37459e1 [AiHe] set fixed seed
1e98a4c [AiHe] [MLLIB][tree] Add reservoir sample in RandomForest
parent d7b69946
No related branches found
No related tags found
No related merge requests found
......@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
import org.apache.spark.util.random.SamplingUtils
/**
* :: Experimental ::
......@@ -473,9 +474,8 @@ object RandomForest extends Serializable with Logging {
val (treeIndex, node) = nodeQueue.head
// Choose subset of features for node (if subsampling).
val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
// TODO: Use more efficient subsampling? (use selection-and-rejection or reservoir)
Some(rng.shuffle(Range(0, metadata.numFeatures).toList)
.take(metadata.numFeaturesPerNode).toArray)
Some(SamplingUtils.reservoirSampleAndCount(Range(0,
metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1)
} else {
None
}
......
......@@ -196,7 +196,6 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
numClasses = 3, categoricalFeaturesInfo = categoricalFeaturesInfo)
val model = RandomForest.trainClassifier(input, strategy, numTrees = 2,
featureSubsetStrategy = "sqrt", seed = 12345)
EnsembleTestHelper.validateClassifier(model, arr, 1.0)
}
test("subsampling rate in RandomForest"){
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment