Commit 9023015f authored by Marcin Tustin, committed by Sean Owen

[SPARK-14163][CORE] SumEvaluator and countApprox cannot reliably handle RDDs of size 1

## What changes were proposed in this pull request?

This special-cases counts of 0 and 1 to avoid passing 0 degrees of freedom to TDistribution.
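
For context, the failure this guards against can be reproduced directly with the commons-math3 library Spark uses here (a minimal sketch, not part of the patch):

```scala
import org.apache.commons.math3.distribution.TDistribution

// With a single merged value, counter.count - 1 == 0, and commons-math3
// rejects a t-distribution with 0 degrees of freedom:
new TDistribution(0)  // throws NotStrictlyPositiveException: degrees of freedom (0)
```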

## How was this patch tested?

Existing tests run successfully. New tests added.

## Note:
This recreates #11982, which was closed due to a non-updated diff. rxin and srowen commented there.
It also adds tests, reworks the code to perform the special casing (based on srowen's comments), adds equality machinery for BoundedDouble, and changes how it is converted to a string.

Author: Marcin Tustin <mtustin@handybook.com>
Author: Marcin Tustin <mtustin@handy.com>

Closes #12016 from mtustin-handy/SPARK-14163.
parent c238cd07
core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
@@ -21,5 +21,23 @@ package org.apache.spark.partial
  * A Double value with error bars and associated confidence.
  */
 class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) {
+
   override def toString(): String = "[%.3f, %.3f]".format(low, high)
+
+  override def hashCode: Int =
+    this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode
+
+  /**
+   * Note that consistent with Double, any NaN value will make equality false
+   */
+  override def equals(that: Any): Boolean =
+    that match {
+      case that: BoundedDouble => {
+        this.mean == that.mean &&
+        this.confidence == that.confidence &&
+        this.low == that.low &&
+        this.high == that.high
+      }
+      case _ => false
+    }
 }
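
As an illustration (not part of the patch) of the new equality semantics, including the NaN caveat in the doc comment:

```scala
import org.apache.spark.partial.BoundedDouble

val a = new BoundedDouble(38.0, 0.95, Double.NegativeInfinity, Double.PositiveInfinity)
val b = new BoundedDouble(38.0, 0.95, Double.NegativeInfinity, Double.PositiveInfinity)
assert(a == b)   // field-by-field equality; infinities compare equal

val n = new BoundedDouble(Double.NaN, 0.95, 0.0, 1.0)
assert(n != n)   // NaN != NaN, consistent with Double semantics
```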
core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
@@ -29,8 +29,9 @@ import org.apache.spark.util.StatCounter
 private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
   extends ApproximateEvaluator[StatCounter, BoundedDouble] {
 
+  // modified in merge
   var outputsMerged = 0
-  var counter = new StatCounter
+  val counter = new StatCounter
 
   override def merge(outputId: Int, taskResult: StatCounter) {
     outputsMerged += 1
@@ -40,30 +41,39 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
   override def currentResult(): BoundedDouble = {
     if (outputsMerged == totalOutputs) {
       new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum)
-    } else if (outputsMerged == 0) {
+    } else if (outputsMerged == 0 || counter.count == 0) {
       new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
     } else {
       val p = outputsMerged.toDouble / totalOutputs
       val meanEstimate = counter.mean
-      val meanVar = counter.sampleVariance / counter.count
       val countEstimate = (counter.count + 1 - p) / p
-      val countVar = (counter.count + 1) * (1 - p) / (p * p)
       val sumEstimate = meanEstimate * countEstimate
-      val sumVar = (meanEstimate * meanEstimate * countVar) +
-        (countEstimate * countEstimate * meanVar) +
-        (meanVar * countVar)
-      val sumStdev = math.sqrt(sumVar)
-      val confFactor = {
-        if (counter.count > 100) {
+
+      val meanVar = counter.sampleVariance / counter.count
+
+      // branch at this point because counter.count == 1 implies counter.sampleVariance == NaN
+      // and we don't want to ever return a bound of NaN
+      if (meanVar.isNaN || counter.count == 1) {
+        new BoundedDouble(sumEstimate, confidence, Double.NegativeInfinity, Double.PositiveInfinity)
+      } else {
+        val countVar = (counter.count + 1) * (1 - p) / (p * p)
+        val sumVar = (meanEstimate * meanEstimate * countVar) +
+          (countEstimate * countEstimate * meanVar) +
+          (meanVar * countVar)
+        val sumStdev = math.sqrt(sumVar)
+        val confFactor = if (counter.count > 100) {
           new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
         } else {
+          // note that if this goes to 0, TDistribution will throw an exception.
+          // Hence special casing 1 above.
           val degreesOfFreedom = (counter.count - 1).toInt
           new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2)
         }
+        val low = sumEstimate - confFactor * sumStdev
+        val high = sumEstimate + confFactor * sumStdev
+        new BoundedDouble(sumEstimate, confidence, low, high)
       }
-      val low = sumEstimate - confFactor * sumStdev
-      val high = sumEstimate + confFactor * sumStdev
-      new BoundedDouble(sumEstimate, confidence, low, high)
     }
   }
 }
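
As a usage sketch (not part of the patch; it assumes a standard spark-shell `sc`), the case this fixes surfaces when asking for an approximate sum over very little data:

```scala
// A 10-partition RDD holding a single element.
val rdd = sc.parallelize(Seq(2.0), numSlices = 10)

// With a zero timeout only some partitions are merged, so SumEvaluator must
// produce a partial BoundedDouble; with one value seen it now falls back to
// infinite bounds instead of NaN bounds or a TDistribution(0) failure.
val partial = rdd.sumApprox(timeout = 0L)
println(partial.initialValue)  // e.g. [-Infinity, Infinity]
```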
core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala (new file)
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.partial

import org.apache.spark._
import org.apache.spark.util.StatCounter

class SumEvaluatorSuite extends SparkFunSuite with SharedSparkContext {

  test("correct handling of count 1") {
    // setup
    val counter = new StatCounter(List(2.0))
    // count of 10 because it's larger than 1,
    // and 0.95 because that's the default
    val evaluator = new SumEvaluator(10, 0.95)
    // arbitrarily assign id 1
    evaluator.merge(1, counter)

    // execute
    val res = evaluator.currentResult()
    // 38.0 - 7.1E-15 because that's how the maths shakes out
    val targetMean = 38.0 - 7.1E-15

    // Sanity check that equality works on BoundedDouble
    assert(new BoundedDouble(2.0, 0.95, 1.1, 1.2) == new BoundedDouble(2.0, 0.95, 1.1, 1.2))

    // actual test
    assert(res ==
      new BoundedDouble(targetMean, 0.950, Double.NegativeInfinity, Double.PositiveInfinity))
  }

  test("correct handling of count 0") {
    // setup
    val counter = new StatCounter(List())
    // count of 10 because it's larger than 0,
    // and 0.95 because that's the default
    val evaluator = new SumEvaluator(10, 0.95)
    // arbitrarily assign id 1
    evaluator.merge(1, counter)

    // execute
    val res = evaluator.currentResult()
    // assert
    assert(res == new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity))
  }

  test("correct handling of NaN") {
    // setup
    val counter = new StatCounter(List(1, Double.NaN, 2))
    // count of 10 because it's larger than 0,
    // and 0.95 because that's the default
    val evaluator = new SumEvaluator(10, 0.95)
    // arbitrarily assign id 1
    evaluator.merge(1, counter)

    // execute
    val res = evaluator.currentResult()
    // assert - note semantics of == in face of NaN
    assert(res.mean.isNaN)
    assert(res.confidence == 0.95)
    assert(res.low == Double.NegativeInfinity)
    assert(res.high == Double.PositiveInfinity)
  }

  test("correct handling of > 1 values") {
    // setup
    val counter = new StatCounter(List(1, 3, 2))
    // count of 10 because it's larger than 0,
    // and 0.95 because that's the default
    val evaluator = new SumEvaluator(10, 0.95)
    // arbitrarily assign id 1
    evaluator.merge(1, counter)

    // execute
    val res = evaluator.currentResult()

    // These vals because that's how the maths shakes out
    val targetMean = 78.0
    val targetLow = -117.617 + 2.732357258139473E-5
    val targetHigh = 273.617 - 2.7323572624027292E-5
    val target = new BoundedDouble(targetMean, 0.95, targetLow, targetHigh)

    // check that values are within expected tolerance of expectation
    assert(res == target)
  }
}
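
For the record, a short sketch (my arithmetic, not part of the patch) of where `targetMean` in the "count 1" test comes from:

```scala
// One output of ten merged, one value (2.0) seen:
val p = 1.0 / 10                       // outputsMerged.toDouble / totalOutputs
val countEstimate = (1 + 1 - p) / p    // (counter.count + 1 - p) / p = 18.999999999999996
val sumEstimate = 2.0 * countEstimate  // counter.mean * countEstimate ≈ 38.0 - 7.1E-15
// counter.sampleVariance is NaN for a single value, so the new branch
// returns bounds of (-Infinity, Infinity) rather than NaN.
```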