diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index dd9a5f261f60f86c0190fe5b063c0b18f0b0518a..afbb9d974d42a22b50614896d41ec4bafa730122 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -477,8 +477,8 @@ private[ml] object RandomForest extends Logging {
       // Construct a nodeStatsAggregators array to hold node aggregate stats,
       // each node will have a nodeStatsAggregator
       val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
-        val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
-          Some(nodeToFeatures(nodeIndex))
+        val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
+          nodeToFeatures(nodeIndex)
         }
         new DTStatsAggregator(metadata, featuresForNode)
       }
@@ -827,8 +827,8 @@ private[ml] object RandomForest extends Logging {
     val numFeatures = metadata.numFeatures

     // Sample the input only if there are continuous features.
-    val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
-    val sampledInput = if (hasContinuousFeatures) {
+    val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
+    val sampledInput = if (continuousFeatures.nonEmpty) {
       // Calculate the number of samples for approximate quantile calculation.
       val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
       val fraction = if (requiredSamples < metadata.numExamples) {
@@ -837,58 +837,57 @@ private[ml] object RandomForest extends Logging {
         1.0
       }
       logDebug("fraction of data used for calculating quantiles = " + fraction)
-      input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
+      input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
     } else {
-      new Array[LabeledPoint](0)
+      input.sparkContext.emptyRDD[LabeledPoint]
     }

-    val splits = new Array[Array[Split]](numFeatures)
-
-    // Find all splits.
-    // Iterate over all features.
-    var featureIndex = 0
-    while (featureIndex < numFeatures) {
-      if (metadata.isContinuous(featureIndex)) {
-        val featureSamples = sampledInput.map(_.features(featureIndex))
-        val featureSplits = findSplitsForContinuousFeature(featureSamples, metadata, featureIndex)
+    findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
+  }

-        val numSplits = featureSplits.length
-        logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
-        splits(featureIndex) = new Array[Split](numSplits)
+  private def findSplitsBinsBySorting(
+      input: RDD[LabeledPoint],
+      metadata: DecisionTreeMetadata,
+      continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
+
+    val continuousSplits: scala.collection.Map[Int, Array[Split]] = {
+      // reduce the parallelism for split computations when there are less
+      // continuous features than input partitions. this prevents tasks from
+      // being spun up that will definitely do no work.
+      val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
+
+      input
+        .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
+        .groupByKey(numPartitions)
+        .map { case (idx, samples) =>
+          val thresholds = findSplitsForContinuousFeature(samples, metadata, idx)
+          val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh))
+          logDebug(s"featureIndex = $idx, numSplits = ${splits.length}")
+          (idx, splits)
+        }.collectAsMap()
+    }

-        var splitIndex = 0
-        while (splitIndex < numSplits) {
-          val threshold = featureSplits(splitIndex)
-          splits(featureIndex)(splitIndex) = new ContinuousSplit(featureIndex, threshold)
-          splitIndex += 1
-        }
-      } else {
-        // Categorical feature
-        if (metadata.isUnordered(featureIndex)) {
-          val numSplits = metadata.numSplits(featureIndex)
-          val featureArity = metadata.featureArity(featureIndex)
-          // TODO: Use an implicit representation mapping each category to a subset of indices.
-          //       I.e., track indices such that we can calculate the set of bins for which
-          //       feature value x splits to the left.
-          // Unordered features
-          // 2^(maxFeatureValue - 1) - 1 combinations
-          splits(featureIndex) = new Array[Split](numSplits)
-          var splitIndex = 0
-          while (splitIndex < numSplits) {
-            val categories: List[Double] =
-              extractMultiClassCategories(splitIndex + 1, featureArity)
-            splits(featureIndex)(splitIndex) =
-              new CategoricalSplit(featureIndex, categories.toArray, featureArity)
-            splitIndex += 1
-          }
-        } else {
-          // Ordered features
-          //   Bins correspond to feature values, so we do not need to compute splits or bins
-          //   beforehand. Splits are constructed as needed during training.
-          splits(featureIndex) = new Array[Split](0)
+    val numFeatures = metadata.numFeatures
+    val splits: Array[Array[Split]] = Array.tabulate(numFeatures) {
+      case i if metadata.isContinuous(i) =>
+        val split = continuousSplits(i)
+        metadata.setNumSplits(i, split.length)
+        split
+
+      case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
+        // Unordered features
+        // 2^(maxFeatureValue - 1) - 1 combinations
+        val featureArity = metadata.featureArity(i)
+        Array.tabulate[Split](metadata.numSplits(i)) { splitIndex =>
+          val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
+          new CategoricalSplit(i, categories.toArray, featureArity)
         }
-      }
-      featureIndex += 1
+
+      case i if metadata.isCategorical(i) =>
+        // Ordered features
+        //   Bins correspond to feature values, so we do not need to compute splits or bins
+        //   beforehand. Splits are constructed as needed during training.
+        Array.empty[Split]
     }
     splits
   }
@@ -930,7 +929,7 @@ private[ml] object RandomForest extends Logging {
    * @return array of splits
    */
   private[tree] def findSplitsForContinuousFeature(
-      featureSamples: Array[Double],
+      featureSamples: Iterable[Double],
       metadata: DecisionTreeMetadata,
       featureIndex: Int): Array[Double] = {
     require(metadata.isContinuous(featureIndex),
@@ -940,8 +939,9 @@ private[ml] object RandomForest extends Logging {
     val numSplits = metadata.numSplits(featureIndex)

     // get count for each distinct value
-    val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
-      m + ((x, m.getOrElse(x, 0) + 1))
+    val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
+      case ((m, cnt), x) =>
+        (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
     }
     // sort distinct values
     val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
@@ -952,7 +952,7 @@ private[ml] object RandomForest extends Logging {
       valueCounts.map(_._1)
     } else {
       // stride between splits
-      val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+      val stride: Double = numSamples.toDouble / (numSplits + 1)
       logDebug("stride = " + stride)

       // iterate `valueCount` to find splits
@@ -988,8 +988,6 @@ private[ml] object RandomForest extends Logging {
     assert(splits.length > 0,
       s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
         " Please remove this feature and then try again.")
-    // set number of splits accordingly
-    metadata.setNumSplits(featureIndex, splits.length)

     splits
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index c0934d241f50a59018706267729d61e0191448db..8f02e098acc304a0fa51e980af65960611953c5d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -1010,7 +1010,7 @@ object DecisionTree extends Serializable with Logging {
         featureSamples: Iterable[Double]): (Int, (Array[Split], Array[Bin])) = {
       val splits = {
         val featureSplits = findSplitsForContinuousFeature(
-          featureSamples.toArray,
+          featureSamples,
           metadata,
           featureIndex)
         logDebug(s"featureIndex = $featureIndex, numSplits = ${featureSplits.length}")
@@ -1115,7 +1115,7 @@ object DecisionTree extends Serializable with Logging {
    * @return Array of splits.
    */
   private[tree] def findSplitsForContinuousFeature(
-      featureSamples: Array[Double],
+      featureSamples: Iterable[Double],
       metadata: DecisionTreeMetadata,
       featureIndex: Int): Array[Double] = {
     require(metadata.isContinuous(featureIndex),
@@ -1125,8 +1125,9 @@ object DecisionTree extends Serializable with Logging {
     val numSplits = metadata.numSplits(featureIndex)

     // get count for each distinct value
-    val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
-      m + ((x, m.getOrElse(x, 0) + 1))
+    val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
+      case ((m, cnt), x) =>
+        (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
     }
     // sort distinct values
     val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
@@ -1137,7 +1138,7 @@ object DecisionTree extends Serializable with Logging {
       valueCounts.map(_._1)
     } else {
       // stride between splits
-      val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+      val stride: Double = numSamples.toDouble / (numSplits + 1)
       logDebug("stride = " + stride)

       // iterate `valueCount` to find splits
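
Note (illustrative sketch, not part of the patch): the standalone program below shows, on toy data, the two techniques the diff combines. First, split computation is distributed by emitting one (featureIndex, value) pair per continuous feature per point and grouping by feature, with the shuffle parallelism capped at the number of continuous features so no task is scheduled that is guaranteed to receive no keys. Second, a single foldLeft pass returns the per-value counts together with the total sample count, which is what lets featureSamples become an Iterable[Double] with no call to .length. The object name, the toy data, and the midpoint heuristic standing in for findSplitsForContinuousFeature are assumptions for illustration only.

import org.apache.spark.{SparkConf, SparkContext}

object SplitFindingSketch {

  // Mirrors the patched foldLeft: one traversal of the Iterable yields both
  // the per-value counts and the sample count, so .size is never needed.
  def countValues(featureSamples: Iterable[Double]): (Map[Double, Int], Int) =
    featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
      case ((m, cnt), x) => (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
    }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("split-finding-sketch").setMaster("local[2]"))

    // Toy input: four points, two continuous features each.
    val points = sc.parallelize(Seq(
      Array(1.0, 10.0), Array(2.0, 20.0), Array(2.0, 30.0), Array(4.0, 40.0)), 4)
    val continuousFeatures = 0 until 2

    // As in the patch: never launch more split-finding tasks than there
    // are continuous features.
    val numPartitions = math.min(continuousFeatures.length, points.partitions.length)

    val perFeature = points
      .flatMap(p => continuousFeatures.map(idx => (idx, p(idx)))) // one (feature, value) pair per feature per point
      .groupByKey(numPartitions)                                  // all samples of a feature land in one task
      .mapValues { samples =>
        val (valueCounts, numSamples) = countValues(samples)
        val distinct = valueCounts.keys.toArray.sorted
        // Hypothetical stand-in for findSplitsForContinuousFeature: midpoints
        // between adjacent distinct values (the real code strides through the
        // value counts using numSamples / (numSplits + 1)).
        val thresholds = distinct.sliding(2).collect { case Array(a, b) => (a + b) / 2 }.toArray
        (numSamples, thresholds)
      }
      .collectAsMap()

    perFeature.toSeq.sortBy(_._1).foreach { case (idx, (n, ts)) =>
      println(s"feature $idx: $n samples, thresholds [${ts.mkString(", ")}]")
    }
    sc.stop()
  }
}

Because groupByKey hash-partitions by feature index, at most one partition per distinct continuous feature can be non-empty, which is why capping numPartitions at continuousFeatures.length costs nothing and avoids idle tasks.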