Commit 3d194cc7 authored by Sean Owen, committed by Xiangrui Meng

SPARK-4547 [MLLIB] OOM when making bins in BinaryClassificationMetrics

Now that I've implemented the basics here, I'm somewhat less convinced there is a need for this change. Callers can downsample before or after. Really, the OOM is not in the ROC curve code, but in code that might `collect()` it for local analysis. Still, it might be useful to down-sample, since the ROC curve probably never needs millions of points.

This is a first pass. Since the `(score, label)` pairs are already grouped and sorted, I think it's sufficient to just take every Nth such pair in order to downsample by a factor of N. This is just like retaining every Nth point on the curve, which I think is the goal. All of the data is still used to build the curve, of course. A small local illustration of the binning arithmetic follows.
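
A rough sketch of that arithmetic on a plain Seq (illustration only; the Seq just stands in for the sorted, distinct `(score, count)` pairs the real code handles as an RDD):

    // 10 distinct scores, each paired with a toy count
    val points = (1 to 10).map(i => (i / 10.0, i))
    val numBins = 4
    val grouping = points.size / numBins        // floor(10 / 4) = 2 points per bin
    val binned = points.grouped(grouping).map { chunk =>
      // keep the first score of the chunk, aggregate the counts of the whole chunk
      (chunk.head._1, chunk.map(_._2).sum)
    }.toList
    // binned has 5 elements, not exactly numBins = 4, as the scaladoc caveat notes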

What do you think about the API and its usefulness?
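
For context, a minimal sketch of how a caller might use the new two-argument constructor so that collecting the curve locally stays small (the helper name `summarizeROC` is hypothetical and only for illustration):

    import org.apache.spark.rdd.RDD
    import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

    def summarizeROC(scoreAndLabels: RDD[(Double, Double)]): Array[(Double, Double)] = {
      // Down-sample the internal curves to roughly 1000 bins so that collecting
      // the ROC curve for local plotting stays small regardless of input size.
      val metrics = new BinaryClassificationMetrics(scoreAndLabels, 1000)
      metrics.roc().collect()
    }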

Author: Sean Owen <sowen@cloudera.com>

Closes #3702 from srowen/SPARK-4547 and squashes the following commits:

1d34d05 [Sean Owen] Indent and reorganize numBins scaladoc
692d825 [Sean Owen] Change handling of large numBins, make 2nd constructor instead of optional param, style change
a03610e [Sean Owen] Add downsamplingFactor to BinaryClassificationMetrics
parent 8e14c5eb
@@ -28,9 +28,30 @@ import org.apache.spark.rdd.{RDD, UnionRDD}
* Evaluator for binary classification.
*
* @param scoreAndLabels an RDD of (score, label) pairs.
* @param numBins if greater than 0, then the curves (ROC curve, PR curve) computed internally
* will be down-sampled to this many "bins". If 0, no down-sampling will occur.
* This is useful because the curve contains a point for each distinct score
* in the input, and this could be as large as the input itself -- millions of
* points or more, when thousands may be entirely sufficient to summarize
* the curve. After down-sampling, the curves will be made of approximately
* `numBins` points instead. Points are made from bins of equal numbers of
* consecutive points. The size of each bin is
* `floor(scoreAndLabels.count() / numBins)`, which means the resulting number
* of bins may not exactly equal numBins. The last bin in each partition may
* be smaller as a result, meaning there may be an extra sample at
* partition boundaries.
*/
@Experimental
class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)]) extends Logging {
class BinaryClassificationMetrics(
val scoreAndLabels: RDD[(Double, Double)],
val numBins: Int) extends Logging {
require(numBins >= 0, "numBins must be nonnegative")
/**
* Defaults `numBins` to 0.
*/
def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0)
/** Unpersist intermediate RDDs used in the computation. */
def unpersist() {
@@ -103,7 +124,39 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)]) extends
mergeValue = (c: BinaryLabelCounter, label: Double) => c += label,
mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2
).sortByKey(ascending = false)
val agg = counts.values.mapPartitions { iter =>
val binnedCounts =
// Only down-sample if numBins > 0
if (numBins == 0) {
// Use original directly
counts
} else {
val countsSize = counts.count()
// Group the iterator into chunks of about countsSize / numBins points,
// so that the resulting number of bins is about numBins
var grouping = countsSize / numBins
if (grouping < 2) {
// numBins was more than half of the size; no real point in down-sampling to bins
logInfo(s"Curve is too small ($countsSize) for $numBins bins to be useful")
counts
} else {
if (grouping >= Int.MaxValue) {
logWarning(
s"Curve too large ($countsSize) for $numBins bins; capping at ${Int.MaxValue}")
grouping = Int.MaxValue
}
counts.mapPartitions(_.grouped(grouping.toInt).map { pairs =>
// The score of the combined point will be just the first one's score
val firstScore = pairs.head._1
// The point will contain all counts in this chunk
val agg = new BinaryLabelCounter()
pairs.foreach(pair => agg += pair._2)
(firstScore, agg)
})
}
}
val agg = binnedCounts.values.mapPartitions { iter =>
val agg = new BinaryLabelCounter()
iter.foreach(agg += _)
Iterator(agg)
@@ -113,7 +166,7 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)]) extends
(agg: BinaryLabelCounter, c: BinaryLabelCounter) => agg.clone() += c)
val totalCount = partitionwiseCumulativeCounts.last
logInfo(s"Total counts: $totalCount")
val cumulativeCounts = counts.mapPartitionsWithIndex(
val cumulativeCounts = binnedCounts.mapPartitionsWithIndex(
(index: Int, iter: Iterator[(Double, BinaryLabelCounter)]) => {
val cumCount = partitionwiseCumulativeCounts(index)
iter.map { case (score, c) =>
@@ -124,4 +124,40 @@ class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext
validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
}
test("binary evaluation metrics with downsampling") {
val scoreAndLabels = Seq(
(0.1, 0.0), (0.2, 0.0), (0.3, 1.0), (0.4, 0.0), (0.5, 0.0),
(0.6, 1.0), (0.7, 1.0), (0.8, 0.0), (0.9, 1.0))
val scoreAndLabelsRDD = sc.parallelize(scoreAndLabels, 1)
val original = new BinaryClassificationMetrics(scoreAndLabelsRDD)
val originalROC = original.roc().collect().sorted.toList
// Add 2 for (0,0) and (1,1) appended at either end
assert(2 + scoreAndLabels.size == originalROC.size)
assert(
List(
(0.0, 0.0), (0.0, 0.25), (0.2, 0.25), (0.2, 0.5), (0.2, 0.75),
(0.4, 0.75), (0.6, 0.75), (0.6, 1.0), (0.8, 1.0), (1.0, 1.0),
(1.0, 1.0)
) ==
originalROC)
val numBins = 4
val downsampled = new BinaryClassificationMetrics(scoreAndLabelsRDD, numBins)
val downsampledROC = downsampled.roc().collect().sorted.toList
assert(
// May have to add 1 if the sample factor didn't divide evenly
2 + (numBins + (if (scoreAndLabels.size % numBins == 0) 0 else 1)) ==
downsampledROC.size)
assert(
List(
(0.0, 0.0), (0.2, 0.25), (0.2, 0.75), (0.6, 0.75), (0.8, 1.0),
(1.0, 1.0), (1.0, 1.0)
) ==
downsampledROC)
}
}