Commit 2a5c9307 authored by BenFradet, committed by Nick Pentreath

[SPARK-13962][ML] spark.ml Evaluators should support other numeric types for label

## What changes were proposed in this pull request?

Made BinaryClassificationEvaluator, MulticlassClassificationEvaluator, and RegressionEvaluator accept all numeric types for the label column, casting it to DoubleType internally.
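
A minimal usage sketch of what this enables (assuming a test-style `sqlContext` in scope and Spark 2.0-era APIs); previously the label column had to be `DoubleType`, so an `Int` label threw an `IllegalArgumentException`:

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator

// Labels stored as Int (IntegerType), predictions as Double.
val df = sqlContext.createDataFrame(Seq(
  (0, 0.2), (1, 1.1), (2, 1.9)
)).toDF("label", "prediction")

// With this change the label column is cast to DoubleType internally,
// so no manual cast is needed before evaluating.
val rmse = new RegressionEvaluator().setMetricName("rmse").evaluate(df)
```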

## How was this patch tested?

Unit tests

Author: BenFradet <benjamin.fradet@gmail.com>

Closes #12500 from BenFradet/SPARK-13962.
parent f8709218
Showing with 88 additions and 51 deletions
BinaryClassificationEvaluator.scala

@@ -24,6 +24,7 @@ import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, I
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DoubleType

 /**
@@ -73,13 +74,14 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
   override def evaluate(dataset: Dataset[_]): Double = {
     val schema = dataset.schema
     SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT))
-    SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+    SchemaUtils.checkNumericType(schema, $(labelCol))

     // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
-    val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol)).rdd.map {
-      case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label)
-      case Row(rawPrediction: Double, label: Double) => (rawPrediction, label)
-    }
+    val scoreAndLabels =
+      dataset.select(col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map {
+        case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label)
+        case Row(rawPrediction: Double, label: Double) => (rawPrediction, label)
+      }
     val metrics = new BinaryClassificationMetrics(scoreAndLabels)
     val metric = $(metricName) match {
       case "areaUnderROC" => metrics.areaUnderROC()
...
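`SchemaUtils.checkNumericType` itself is outside this diff; a sketch of what it presumably does (an assumption, consistent with the error message asserted in `MLTestingUtils` below):

```scala
import org.apache.spark.sql.types.{NumericType, StructType}

// Sketch: accept any NumericType subclass (Byte, Short, Int, Long,
// Float, Double, Decimal) instead of requiring DoubleType exactly.
def checkNumericType(schema: StructType, colName: String): Unit = {
  val actual = schema(colName).dataType
  require(actual.isInstanceOf[NumericType],
    s"Column $colName must be of type NumericType but was actually of type $actual.")
}
```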
MulticlassClassificationEvaluator.scala

@@ -23,6 +23,7 @@ import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
 import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DoubleType

 /**
@@ -72,12 +73,12 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
   override def evaluate(dataset: Dataset[_]): Double = {
     val schema = dataset.schema
     SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
-    SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+    SchemaUtils.checkNumericType(schema, $(labelCol))

-    val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)).rdd.map {
-      case Row(prediction: Double, label: Double) =>
-        (prediction, label)
-    }
+    val predictionAndLabels =
+      dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map {
+        case Row(prediction: Double, label: Double) => (prediction, label)
+      }
     val metrics = new MulticlassMetrics(predictionAndLabels)
     val metric = $(metricName) match {
       case "f1" => metrics.weightedFMeasure
...
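The `.cast(DoubleType)` on the label column is what keeps the `case Row(prediction: Double, label: Double)` pattern safe for every numeric input; a small illustration (hypothetical data, assumed `sqlContext`):

```scala
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

// The label column starts out as ShortType.
val df = sqlContext.createDataFrame(Seq((1.toShort, 1.0), (0.toShort, 0.0)))
  .toDF("label", "prediction")

// After the cast, both columns hold doubles, so pattern matching on
// Row(prediction: Double, label: Double) cannot hit a MatchError.
val casted = df.select(col("prediction"), col("label").cast(DoubleType))
```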
RegressionEvaluator.scala

@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
 import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -74,22 +74,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
   @Since("2.0.0")
   override def evaluate(dataset: Dataset[_]): Double = {
     val schema = dataset.schema
-    val predictionColName = $(predictionCol)
-    val predictionType = schema($(predictionCol)).dataType
-    require(predictionType == FloatType || predictionType == DoubleType,
-      s"Prediction column $predictionColName must be of type float or double, " +
-        s" but not $predictionType")
-    val labelColName = $(labelCol)
-    val labelType = schema($(labelCol)).dataType
-    require(labelType == FloatType || labelType == DoubleType,
-      s"Label column $labelColName must be of type float or double, but not $labelType")
+    SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType))
+    SchemaUtils.checkNumericType(schema, $(labelCol))
     val predictionAndLabels = dataset
       .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
-      .rdd.
-      map { case Row(prediction: Double, label: Double) =>
-        (prediction, label)
-      }
+      .rdd
+      .map { case Row(prediction: Double, label: Double) => (prediction, label) }
     val metrics = new RegressionMetrics(predictionAndLabels)
     val metric = $(metricName) match {
       case "rmse" => metrics.rootMeanSquaredError
...
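`SchemaUtils.checkColumnTypes` already existed (it is used for `rawPredictionCol` above); for reference, a sketch consistent with the "equal to one of the following types: [DoubleType, ..." message asserted in the suite below (the exact implementation is an assumption):

```scala
import org.apache.spark.sql.types.{DataType, StructType}

// Sketch: require the column's type to be one of a whitelist.
def checkColumnTypes(
    schema: StructType,
    colName: String,
    dataTypes: Seq[DataType]): Unit = {
  val actual = schema(colName).dataType
  require(dataTypes.exists(actual.equals),
    s"Column $colName must be of type equal to one of the following types: " +
      s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actual.")
}
```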
BinaryClassificationEvaluatorSuite.scala

@@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -68,4 +68,9 @@ class BinaryClassificationEvaluatorSuite
       "equal to one of the following types: [DoubleType, ")
     assert(thrown.getMessage.replace("\n", "") contains "but was actually of type StringType.")
   }
+
+  test("should support all NumericType labels and not support other types") {
+    val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction")
+    MLTestingUtils.checkNumericTypes(evaluator, sqlContext)
+  }
 }
MulticlassClassificationEvaluatorSuite.scala

@@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.MLlibTestSparkContext

 class MulticlassClassificationEvaluatorSuite
@@ -36,4 +36,8 @@ class MulticlassClassificationEvaluatorSuite
       .setMetricName("recall")
     testDefaultReadWrite(evaluator)
   }
+
+  test("should support all NumericType labels and not support other types") {
+    MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, sqlContext)
+  }
 }
RegressionEvaluatorSuite.scala

@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.LinearRegression
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
@@ -83,4 +83,8 @@ class RegressionEvaluatorSuite
       .setMetricName("r2")
     testDefaultReadWrite(evaluator)
   }
+
+  test("should support all NumericType labels and not support other types") {
+    MLTestingUtils.checkNumericTypes(new RegressionEvaluator, sqlContext)
+  }
 }
TreeTests.scala

@@ -79,16 +79,21 @@ private[ml] object TreeTests extends SparkFunSuite {
    *                 This must be non-empty.
    * @param numClasses  Number of classes label can take. If 0, mark as continuous.
    * @param labelColName  Name of the label column on which to set the metadata.
+   * @param featuresColName  Name of the features column
    * @return DataFrame with metadata
    */
-  def setMetadata(data: DataFrame, numClasses: Int, labelColName: String): DataFrame = {
+  def setMetadata(
+      data: DataFrame,
+      numClasses: Int,
+      labelColName: String,
+      featuresColName: String): DataFrame = {
     val labelAttribute = if (numClasses == 0) {
       NumericAttribute.defaultAttr.withName(labelColName)
     } else {
       NominalAttribute.defaultAttr.withName(labelColName).withNumValues(numClasses)
     }
     val labelMetadata = labelAttribute.toMetadata()
-    data.select(data("features"), data(labelColName).as(labelColName, labelMetadata))
+    data.select(data(featuresColName), data(labelColName).as(labelColName, labelMetadata))
   }

   /**
...
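A call-site sketch for the widened signature (the DataFrame `df` and column names are illustrative):

```scala
// Previously setMetadata(df, 2, "label") hard-coded the "features"
// column; callers now pass the features column name explicitly.
val withMeta = TreeTests.setMetadata(df, 2, "label", "features")
```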
MLTestingUtils.scala

@@ -19,6 +19,7 @@ package org.apache.spark.ml.util
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.mllib.linalg.Vectors
@@ -47,12 +48,30 @@ object MLTestingUtils extends SparkFunSuite {
     val actuals = dfs.keys.filter(_ != DoubleType).map(t => estimator.fit(dfs(t)))
     actuals.foreach(actual => check(expected, actual))

-    val dfWithStringLabels = generateDFWithStringLabelCol(sqlContext)
+    val dfWithStringLabels = sqlContext.createDataFrame(Seq(
+      ("0", Vectors.dense(0, 2, 3), 0.0)
+    )).toDF("label", "features", "censor")
     val thrown = intercept[IllegalArgumentException] {
       estimator.fit(dfWithStringLabels)
     }
-    assert(thrown.getMessage contains
-      "Column label must be of type NumericType but was actually of type StringType")
+    assert(thrown.getMessage.contains(
+      "Column label must be of type NumericType but was actually of type StringType"))
+  }
+
+  def checkNumericTypes[T <: Evaluator](evaluator: T, sqlContext: SQLContext): Unit = {
+    val dfs = genEvaluatorDFWithNumericLabelCol(sqlContext, "label", "prediction")
+    val expected = evaluator.evaluate(dfs(DoubleType))
+    val actuals = dfs.keys.filter(_ != DoubleType).map(t => evaluator.evaluate(dfs(t)))
+    actuals.foreach(actual => assert(expected === actual))
+
+    val dfWithStringLabels = sqlContext.createDataFrame(Seq(
+      ("0", 0d)
+    )).toDF("label", "prediction")
+    val thrown = intercept[IllegalArgumentException] {
+      evaluator.evaluate(dfWithStringLabels)
+    }
+    assert(thrown.getMessage.contains(
+      "Column label must be of type NumericType but was actually of type StringType"))
   }

   def genClassifDFWithNumericLabelCol(
@@ -69,9 +88,10 @@ object MLTestingUtils extends SparkFunSuite {
     val types =
       Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
-    types.map(t => t -> df.select(col(labelColName).cast(t), col(featuresColName)))
-      .map { case (t, d) => t -> TreeTests.setMetadata(d, 2, labelColName) }
-      .toMap
+    types.map { t =>
+      val castDF = df.select(col(labelColName).cast(t), col(featuresColName))
+      t -> TreeTests.setMetadata(castDF, 2, labelColName, featuresColName)
+    }.toMap
   }

   def genRegressionDFWithNumericLabelCol(
@@ -89,24 +109,29 @@ object MLTestingUtils extends SparkFunSuite {
     val types =
       Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
-    types
-      .map(t => t -> df.select(col(labelColName).cast(t), col(featuresColName)))
-      .map { case (t, d) =>
-        t -> TreeTests.setMetadata(d, 0, labelColName).withColumn(censorColName, lit(0.0))
-      }
-      .toMap
+    types.map { t =>
+      val castDF = df.select(col(labelColName).cast(t), col(featuresColName))
+      t -> TreeTests.setMetadata(castDF, 0, labelColName, featuresColName)
+        .withColumn(censorColName, lit(0.0))
+    }.toMap
   }

-  def generateDFWithStringLabelCol(
+  def genEvaluatorDFWithNumericLabelCol(
       sqlContext: SQLContext,
       labelColName: String = "label",
-      featuresColName: String = "features",
-      censorColName: String = "censor"): DataFrame =
-    sqlContext.createDataFrame(Seq(
-      ("0", Vectors.dense(0, 2, 3), 0.0),
-      ("1", Vectors.dense(0, 3, 1), 1.0),
-      ("0", Vectors.dense(0, 2, 2), 0.0),
-      ("1", Vectors.dense(0, 3, 9), 1.0),
-      ("0", Vectors.dense(0, 2, 6), 0.0)
-    )).toDF(labelColName, featuresColName, censorColName)
+      predictionColName: String = "prediction"): Map[NumericType, DataFrame] = {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, 0d),
+      (1, 1d),
+      (2, 2d),
+      (3, 3d),
+      (4, 4d)
+    )).toDF(labelColName, predictionColName)
+
+    val types =
+      Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
+    types
+      .map(t => t -> df.select(col(labelColName).cast(t), col(predictionColName)))
+      .toMap
+  }
 }
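A usage sketch of the new helper from inside a suite (assuming `sqlContext` is in scope, as in the suites above):

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.sql.types.DoubleType

// The helper returns Map[NumericType, DataFrame]; every numeric label
// variant should produce the same metric as the DoubleType baseline.
val dfs = MLTestingUtils.genEvaluatorDFWithNumericLabelCol(sqlContext)
val evaluator = new RegressionEvaluator()
val baseline = evaluator.evaluate(dfs(DoubleType))
dfs.values.foreach(df => assert(evaluator.evaluate(df) == baseline))
```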