Skip to content
Snippets Groups Projects
Commit 01dd1f5c authored by Yong Tang's avatar Yong Tang Committed by Xiangrui Meng
Browse files

[SPARK-14565][ML] RandomForest should use parseInt and parseDouble for feature...

[SPARK-14565][ML] RandomForest should use parseInt and parseDouble for feature subset size instead of regexes

## What changes were proposed in this pull request?

This fix tries to change RandomForest's supported strategies from using regexes to using parseInt and
parseDouble, for the purpose of robustness and maintainability.

## How was this patch tested?

Existing tests passed.

Author: Yong Tang <yong.tang.github@outlook.com>

Closes #12360 from yongtang/SPARK-14565.
parent d7e124ed
No related branches found
No related tags found
No related merge requests found
......@@ -18,8 +18,10 @@
package org.apache.spark.ml.tree.impl
import scala.collection.mutable
import scala.util.Try
import org.apache.spark.internal.Logging
import org.apache.spark.ml.tree.RandomForestParams
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
......@@ -184,15 +186,22 @@ private[spark] object DecisionTreeMetadata extends Logging {
case _ => featureSubsetStrategy
}
val isIntRegex = "^([1-9]\\d*)$".r
val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r
val numFeaturesPerNode: Int = _featureSubsetStrategy match {
case "all" => numFeatures
case "sqrt" => math.sqrt(numFeatures).ceil.toInt
case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
case "onethird" => (numFeatures / 3.0).ceil.toInt
case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt
case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt
case _ =>
Try(_featureSubsetStrategy.toInt).filter(_ > 0).toOption match {
case Some(value) => math.min(value, numFeatures)
case None =>
Try(_featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).toOption match {
case Some(value) => math.ceil(value * numFeatures).toInt
case _ => throw new IllegalArgumentException(s"Supported values:" +
s" ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
s" (0.0-1.0], [1-n].")
}
}
}
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
......
......@@ -17,6 +17,8 @@
package org.apache.spark.ml.tree
import scala.util.Try
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
......@@ -346,10 +348,12 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
*/
final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
"The number of features to consider for splits at each tree node." +
s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" +
s", (0.0-1.0], [1-n].",
(value: String) =>
RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)
|| value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex))
|| Try(value.toInt).filter(_ > 0).isSuccess
|| Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess)
setDefault(featureSubsetStrategy -> "auto")
......@@ -396,9 +400,6 @@ private[spark] object RandomForestParams {
// These options should be lowercase.
final val supportedFeatureSubsetStrategies: Array[String] =
Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
// The regex to capture "(0.0-1.0]", and "n" for integer 0 < n <= (number of features)
final val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$"
}
private[ml] trait RandomForestClassifierParams
......
......@@ -18,6 +18,7 @@
package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
import scala.util.Try
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
......@@ -76,9 +77,10 @@ private class RandomForest (
strategy.assertValid()
require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy)
|| featureSubsetStrategy.matches(NewRFParams.supportedFeatureSubsetStrategiesRegex),
|| Try(featureSubsetStrategy.toInt).filter(_ > 0).isSuccess
|| Try(featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess,
s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." +
s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}," +
s" Supported values: ${NewRFParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
s" (0.0-1.0], [1-n].")
/**
......
......@@ -440,7 +440,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0")
for (invalidStrategy <- invalidStrategies) {
intercept[MatchError]{
intercept[IllegalArgumentException]{
val metadata =
DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy)
}
......@@ -463,7 +463,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
checkFeatureSubsetStrategy(numTrees = 2, strategy, expected)
}
for (invalidStrategy <- invalidStrategies) {
intercept[MatchError]{
intercept[IllegalArgumentException]{
val metadata =
DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy)
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment