Commit 4f4721a2 authored by Joseph K. Bradley

[SPARK-14862][ML] Updated Classifiers to not require labelCol metadata

## What changes were proposed in this pull request?

Updated Classifier, DecisionTreeClassifier, RandomForestClassifier, GBTClassifier to not require input column metadata.
* They first check for metadata.
* If numClasses is not specified in metadata, they identify the largest label value (up to a limit).

This functionality is implemented in a new Classifier.getNumClasses method.

Also
* Updated Classifier.extractLabeledPoints to (a) check label values and (b) include a second version which takes a numClasses value for validity checking.
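
To illustrate the new behavior, here is a minimal sketch (hypothetical toy data; "label" and "features" are the default column names). Before this patch, fitting on a DataFrame whose label column lacked metadata threw an IllegalArgumentException pointing at StringIndexer; now numClasses is inferred from the data:

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Toy 3-class dataset; createDataFrame attaches no metadata to the label column.
val df = sqlContext.createDataFrame(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
  LabeledPoint(2.0, Vectors.dense(1.0, 1.0))))

// Previously: IllegalArgumentException ("... See StringIndexer.").
// Now: numClasses is inferred as 3 (largest label + 1) and training proceeds.
val model = new DecisionTreeClassifier().setMaxDepth(2).fit(df)
```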

## How was this patch tested?

* Unit tests in ClassifierSuite for helper methods
* Unit tests for DecisionTreeClassifier, RandomForestClassifier, GBTClassifier with toy datasets lacking label metadata

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #12663 from jkbradley/trees-no-metadata.
parent dae538a4
Showing 8 changed files with 245 additions and 31 deletions
mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -17,14 +17,17 @@
 package org.apache.spark.ml.classification

+import org.apache.spark.SparkException
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
 import org.apache.spark.ml.param.shared.HasRawPredictionCol
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.util.{MetadataUtils, SchemaUtils}
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
* (private[spark]) Params for classification.
@@ -62,6 +65,67 @@ abstract class Classifier[
  def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]

  // TODO: defaultEvaluator (follow-up PR)

  /**
   * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
   * and put them in an RDD with strong types.
   *
   * @param dataset  DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]])
   *                 and features ([[Vector]]). Labels are cast to [[DoubleType]].
   * @param numClasses  Number of classes label can take. Labels must be integers in the range
   *                    [0, numClasses).
   * @throws SparkException  if any label is not an integer in the range [0, numClasses)
   */
  protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
    require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
      s" $numClasses, but requires numClasses > 0.")
    dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
      case Row(label: Double, features: Vector) =>
        require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
          s" dataset with invalid label $label. Labels must be integers in range" +
          s" [0, 1, ..., $numClasses), where numClasses=$numClasses.")
        LabeledPoint(label, features)
    }
  }

  /**
   * Get the number of classes. This looks in the column metadata first; if that is missing, it
   * assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses by finding the
   * maximum label value.
   *
   * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere,
   * such as in [[extractLabeledPoints()]].
   *
   * @param dataset  Dataset which contains a column [[labelCol]]
   * @param maxNumClasses  Maximum number of classes allowed when inferred from data. If numClasses
   *                       is specified in the metadata, then maxNumClasses is ignored.
   * @return  number of classes
   * @throws IllegalArgumentException  if metadata does not specify numClasses, and the
   *                                   actual numClasses exceeds maxNumClasses
   */
  protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = {
    MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
      case Some(n: Int) => n
      case None =>
        // Get number of classes from dataset itself.
        val maxLabelRow: Array[Row] = dataset.select(max($(labelCol))).take(1)
        if (maxLabelRow.isEmpty) {
          throw new SparkException("ML algorithm was given empty dataset.")
        }
        val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0)
        require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label value =" +
          s" $maxDoubleLabel but requires integers in range [0, ... ${Int.MaxValue})")
        val numClasses = maxDoubleLabel.toInt + 1
        require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
          s" in column ${$(labelCol)}, but this exceeded the max numClasses ($maxNumClasses) allowed" +
          s" to be inferred from values. To avoid this error for labels with > $maxNumClasses" +
          s" classes, specify numClasses explicitly in the metadata; this can be done by applying" +
          s" StringIndexer to the label column.")
        logInfo(this.getClass.getCanonicalName + s" inferred $numClasses classes for" +
          s" labelCol=${$(labelCol)} since numClasses was not specified in the column metadata.")
        numClasses
    }
  }
/**
......
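
For context (not part of the diff): when numClasses is already recorded in the label column's metadata, getNumClasses returns it directly and never scans the data, so maxNumClasses does not apply. Below is a sketch of attaching that metadata by hand, assuming a DataFrame df with "label" and "features" columns (StringIndexer produces equivalent metadata automatically):

```scala
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.sql.functions.col

// Record numClasses = 3 on the label column, as StringIndexer would.
val labelMeta = NominalAttribute.defaultAttr.withName("label").withNumValues(3).toMetadata()
val dfWithMeta = df.select(col("label").as("label", labelMeta), col("features"))
// getNumClasses(dfWithMeta) now reads 3 from the metadata instead of scanning labels.
```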
mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -85,14 +85,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
   override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
-      case Some(n: Int) => n
-      case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" +
-        s" with invalid label column ${$(labelCol)}, without the number of classes" +
-        " specified. See StringIndexer.")
-      // TODO: Automatically index labels: SPARK-7126
-    }
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    val numClasses: Int = getNumClasses(dataset)
+    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
     val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
       seed = $(seed), parentUID = Some(uid))
......
mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -35,8 +35,9 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
/**
* :: Experimental ::
@@ -126,16 +127,16 @@ class GBTClassifier @Since("1.4.0") (
   override protected def train(dataset: Dataset[_]): GBTClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
-      case Some(n: Int) => n
-      case None => throw new IllegalArgumentException("GBTClassifier was given input" +
-        s" with invalid label column ${$(labelCol)}, without the number of classes" +
-        " specified. See StringIndexer.")
-      // TODO: Automatically index labels: SPARK-7126
-    }
-    require(numClasses == 2,
-      s"GBTClassifier only supports binary classification but was given numClasses = $numClasses")
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
+    // 2 classes now. This lets us provide a more precise error message.
+    val oldDataset: RDD[LabeledPoint] =
+      dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+        case Row(label: Double, features: Vector) =>
+          require(label == 0 || label == 1, s"GBTClassifier was given" +
+            s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
+            s" GBTClassifier currently only supports binary classification.")
+          LabeledPoint(label, features)
+      }
     val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
     val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
@@ -165,6 +166,7 @@ object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {
* model for classification.
* It supports binary labels, as well as both continuous and categorical features.
* Note: Multiclass labels are not currently supported.
*
* @param _trees Decision trees in the ensemble.
* @param _treeWeights Weights for the decision trees in the ensemble.
*/
@@ -185,6 +187,7 @@ class GBTClassificationModel private[ml](
/**
* Construct a GBTClassificationModel
*
* @param _trees Decision trees in the ensemble.
* @param _treeWeights Weights for the decision trees in the ensemble.
*/
......
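
For context (not part of the diff): because GBTClassifier validates labels itself instead of calling getNumClasses, any label outside {0, 1} fails at extraction time with the more precise message above. A small sketch with hypothetical toy data:

```scala
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val badDf = sqlContext.createDataFrame(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0)),
  LabeledPoint(2.0, Vectors.dense(1.0)))) // label 2.0 is invalid for GBT

val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)
// Fails with a SparkException whose message notes that GBTClassifier
// "currently only supports binary classification".
gbt.fit(badDf)
```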
mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -101,14 +101,8 @@ class RandomForestClassifier @Since("1.4.0") (
   override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
-      case Some(n: Int) => n
-      case None => throw new IllegalArgumentException("RandomForestClassifier was given input" +
-        s" with invalid label column ${$(labelCol)}, without the number of classes" +
-        " specified. See StringIndexer.")
-      // TODO: Automatically index labels: SPARK-7126
-    }
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    val numClasses: Int = getNumClasses(dataset)
+    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
     val trees =
......
mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
@@ -17,6 +17,86 @@
package org.apache.spark.ml.classification
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.classification.ClassifierSuite.MockClassifier
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
test("extractLabeledPoints") {
def getTestData(labels: Seq[Double]): DataFrame = {
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
sqlContext.createDataFrame(data)
}
val c = new MockClassifier
// Valid dataset
val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0))
c.extractLabeledPoints(df0, 6).count()
// Invalid datasets
val df1 = getTestData(Seq(0.0, -2.0, 1.0, 5.0))
withClue("Classifier should fail if label is negative") {
val e: SparkException = intercept[SparkException] {
c.extractLabeledPoints(df1, 6).count()
}
assert(e.getMessage.contains("given dataset with invalid label"))
}
val df2 = getTestData(Seq(0.0, 2.1, 1.0, 5.0))
withClue("Classifier should fail if label is not an integer") {
val e: SparkException = intercept[SparkException] {
c.extractLabeledPoints(df2, 6).count()
}
assert(e.getMessage.contains("given dataset with invalid label"))
}
// extractLabeledPoints with numClasses specified
withClue("Classifier should fail if label is >= numClasses") {
val e: SparkException = intercept[SparkException] {
c.extractLabeledPoints(df0, numClasses = 5).count()
}
assert(e.getMessage.contains("given dataset with invalid label"))
}
withClue("Classifier.extractLabeledPoints should fail if numClasses <= 0") {
val e: IllegalArgumentException = intercept[IllegalArgumentException] {
c.extractLabeledPoints(df0, numClasses = 0).count()
}
assert(e.getMessage.contains("but requires numClasses > 0"))
}
}
test("getNumClasses") {
def getTestData(labels: Seq[Double]): DataFrame = {
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
sqlContext.createDataFrame(data)
}
val c = new MockClassifier
// Valid dataset
val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0))
assert(c.getNumClasses(df0) === 6)
// Invalid datasets
val df1 = getTestData(Seq(0.0, 2.0, 1.0, 5.1))
withClue("getNumClasses should fail if label is max label not an integer") {
val e: IllegalArgumentException = intercept[IllegalArgumentException] {
c.getNumClasses(df1)
}
assert(e.getMessage.contains("requires integers in range"))
}
val df2 = getTestData(Seq(0.0, 2.0, 1.0, Int.MaxValue.toDouble))
withClue("getNumClasses should fail if label is max label is >= Int.MaxValue") {
val e: IllegalArgumentException = intercept[IllegalArgumentException] {
c.getNumClasses(df2)
}
assert(e.getMessage.contains("requires integers in range"))
}
}
}
object ClassifierSuite {
/**
@@ -29,4 +109,32 @@ object ClassifierSuite {
"rawPredictionCol" -> "myRawPrediction"
)
class MockClassifier(override val uid: String)
extends Classifier[Vector, MockClassifier, MockClassificationModel] {
def this() = this(Identifiable.randomUID("mockclassifier"))
override def copy(extra: ParamMap): MockClassifier = throw new NotImplementedError()
override def train(dataset: Dataset[_]): MockClassificationModel =
throw new NotImplementedError()
// Make methods public
override def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] =
super.extractLabeledPoints(dataset, numClasses)
def getNumClasses(dataset: Dataset[_]): Int = super.getNumClasses(dataset)
}
class MockClassificationModel(override val uid: String)
extends ClassificationModel[Vector, MockClassificationModel] {
def this() = this(Identifiable.randomUID("mockclassificationmodel"))
protected def predictRaw(features: Vector): Vector = throw new NotImplementedError()
override def copy(extra: ParamMap): MockClassificationModel = throw new NotImplementedError()
override def numClasses: Int = throw new NotImplementedError()
}
}
mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -342,6 +342,12 @@ class DecisionTreeClassifierSuite
}
}
test("Fitting without numClasses in metadata") {
val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
val dt = new DecisionTreeClassifier().setMaxDepth(1)
dt.fit(df)
}
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
......
mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -17,12 +17,13 @@
 package org.apache.spark.ml.classification

-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -128,6 +129,43 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
}
*/
test("Fitting without numClasses in metadata") {
val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)
gbt.fit(df)
}
test("extractLabeledPoints with bad data") {
def getTestData(labels: Seq[Double]): DataFrame = {
val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }
sqlContext.createDataFrame(data)
}
val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1)
// Invalid datasets
val df1 = getTestData(Seq(0.0, -1.0, 1.0, 0.0))
withClue("Classifier should fail if label is negative") {
val e: SparkException = intercept[SparkException] {
gbt.fit(df1)
}
assert(e.getMessage.contains("currently only supports binary classification"))
}
val df2 = getTestData(Seq(0.0, 0.1, 1.0, 0.0))
withClue("Classifier should fail if label is not an integer") {
val e: SparkException = intercept[SparkException] {
gbt.fit(df2)
}
assert(e.getMessage.contains("currently only supports binary classification"))
}
val df3 = getTestData(Seq(0.0, 2.0, 1.0, 0.0))
withClue("Classifier should fail if label is >= 2") {
val e: SparkException = intercept[SparkException] {
gbt.fit(df3)
}
assert(e.getMessage.contains("currently only supports binary classification"))
}
}
/////////////////////////////////////////////////////////////////////////////
// Tests of feature importance
/////////////////////////////////////////////////////////////////////////////
......
mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -154,9 +154,16 @@ class RandomForestClassifierSuite
}
}
test("Fitting without numClasses in metadata") {
val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc))
val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1)
rf.fit(df)
}
/////////////////////////////////////////////////////////////////////////////
// Tests of feature importance
/////////////////////////////////////////////////////////////////////////////
test("Feature importance with toy data") {
val numClasses = 2
val rf = new RandomForestClassifier()
......