Skip to content
Snippets Groups Projects
Commit 734ed7a7 authored by Sean Owen's avatar Sean Owen
Browse files

[SPARK-21806][MLLIB] BinaryClassificationMetrics pr(): first point (0.0, 1.0) is misleading

## What changes were proposed in this pull request?

Prepend (0,p) to precision-recall curve not (0,1) where p matches lowest recall point

## How was this patch tested?

Updated tests.

Author: Sean Owen <sowen@cloudera.com>

Closes #19038 from srowen/SPARK-21806.
parent 8f0df6bc
No related branches found
No related tags found
No related merge requests found
...@@ -98,16 +98,16 @@ class BinaryClassificationMetrics @Since("1.3.0") ( ...@@ -98,16 +98,16 @@ class BinaryClassificationMetrics @Since("1.3.0") (
/** /**
* Returns the precision-recall curve, which is an RDD of (recall, precision), * Returns the precision-recall curve, which is an RDD of (recall, precision),
* NOT (precision, recall), with (0.0, 1.0) prepended to it. * NOT (precision, recall), with (0.0, p) prepended to it, where p is the precision
* associated with the lowest recall on the curve.
* @see <a href="http://en.wikipedia.org/wiki/Precision_and_recall"> * @see <a href="http://en.wikipedia.org/wiki/Precision_and_recall">
* Precision and recall (Wikipedia)</a> * Precision and recall (Wikipedia)</a>
*/ */
@Since("1.0.0") @Since("1.0.0")
def pr(): RDD[(Double, Double)] = { def pr(): RDD[(Double, Double)] = {
val prCurve = createCurve(Recall, Precision) val prCurve = createCurve(Recall, Precision)
val sc = confusions.context val (_, firstPrecision) = prCurve.first()
val first = sc.makeRDD(Seq((0.0, 1.0)), 1) confusions.context.parallelize(Seq((0.0, firstPrecision)), 1).union(prCurve)
first.union(prCurve)
} }
/** /**
......
...@@ -23,18 +23,16 @@ import org.apache.spark.mllib.util.TestingUtils._ ...@@ -23,18 +23,16 @@ import org.apache.spark.mllib.util.TestingUtils._
class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 private def assertSequencesMatch(actual: Seq[Double], expected: Seq[Double]): Unit = {
actual.zip(expected).foreach { case (a, e) => assert(a ~== e absTol 1.0e-5) }
private def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean =
(x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5)
private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = {
assert(left.zip(right).forall(areWithinEpsilon))
} }
private def assertTupleSequencesMatch(left: Seq[(Double, Double)], private def assertTupleSequencesMatch(actual: Seq[(Double, Double)],
right: Seq[(Double, Double)]): Unit = { expected: Seq[(Double, Double)]): Unit = {
assert(left.zip(right).forall(pairsWithinEpsilon)) actual.zip(expected).foreach { case ((ax, ay), (ex, ey)) =>
assert(ax ~== ex absTol 1.0e-5)
assert(ay ~== ey absTol 1.0e-5)
}
} }
private def validateMetrics(metrics: BinaryClassificationMetrics, private def validateMetrics(metrics: BinaryClassificationMetrics,
...@@ -44,7 +42,7 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark ...@@ -44,7 +42,7 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
expectedFMeasures1: Seq[Double], expectedFMeasures1: Seq[Double],
expectedFmeasures2: Seq[Double], expectedFmeasures2: Seq[Double],
expectedPrecisions: Seq[Double], expectedPrecisions: Seq[Double],
expectedRecalls: Seq[Double]) = { expectedRecalls: Seq[Double]): Unit = {
assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds) assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds)
assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve) assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve)
...@@ -111,7 +109,7 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark ...@@ -111,7 +109,7 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
val fpr = Seq(1.0) val fpr = Seq(1.0)
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0)) val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
val pr = recalls.zip(precisions) val pr = recalls.zip(precisions)
val prCurve = Seq((0.0, 1.0)) ++ pr val prCurve = Seq((0.0, 0.0)) ++ pr
val f1 = pr.map { val f1 = pr.map {
case (0, 0) => 0.0 case (0, 0) => 0.0
case (r, p) => 2.0 * (p * r) / (p + r) case (r, p) => 2.0 * (p * r) / (p + r)
......
0% Loading or an error occurred.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment