Commit db42451a authored by Matei Zaharia

Merge pull request #643 from adatao/master

Bug fix: Zero-length partitions result in NaN for overall mean & variance
Parents: e82a2ffc f91195cc
@@ -37,17 +37,23 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
     if (other == this) {
       merge(other.copy())  // Avoid overwriting fields in a weird order
     } else {
-      val delta = other.mu - mu
-      if (other.n * 10 < n) {
-        mu = mu + (delta * other.n) / (n + other.n)
-      } else if (n * 10 < other.n) {
-        mu = other.mu - (delta * n) / (n + other.n)
-      } else {
-        mu = (mu * n + other.mu * other.n) / (n + other.n)
+      if (n == 0) {
+        mu = other.mu
+        m2 = other.m2
+        n = other.n
+      } else if (other.n != 0) {
+        val delta = other.mu - mu
+        if (other.n * 10 < n) {
+          mu = mu + (delta * other.n) / (n + other.n)
+        } else if (n * 10 < other.n) {
+          mu = other.mu - (delta * n) / (n + other.n)
+        } else {
+          mu = (mu * n + other.mu * other.n) / (n + other.n)
+        }
+        m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
+        n += other.n
       }
-      m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
-      n += other.n
       this
     }
   }
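Why the old code produced NaN: rdd.stats() builds one StatCounter per partition and reduces them with merge(). When the counters of two empty partitions are merged, none of the ratio branches fire (0 < 0 is false), so the general update evaluates (mu * n + other.mu * other.n) / (n + other.n) with n + other.n == 0, yielding 0.0 / 0 = NaN, which then poisons every later merge. A minimal standalone sketch of that failure mode (SimpleStat below is a simplified stand-in for StatCounter, not Spark code):

// Simplified stand-in for StatCounter, to demonstrate the failure mode only.
class SimpleStat(var n: Long = 0, var mu: Double = 0.0, var m2: Double = 0.0) {
  // Old behavior: apply the running-mean update unconditionally.
  def mergeOld(other: SimpleStat): SimpleStat = {
    val delta = other.mu - mu
    mu = (mu * n + other.mu * other.n) / (n + other.n)  // 0.0 / 0 => NaN when both sides are empty
    m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
    n += other.n
    this
  }
}

object NaNDemo extends App {
  val empty1 = new SimpleStat()  // counter from an empty partition
  val empty2 = new SimpleStat()  // counter from a consecutive empty partition
  val merged = empty1.mergeOld(empty2)
  println(merged.mu)  // NaN
  // Once mu is NaN it never recovers, even after merging real data:
  println(merged.mergeOld(new SimpleStat(2, 3.0, 1.0)).mu)  // still NaN
}

The fixed merge guards both sides: an empty `this` simply copies the other counter's fields, and an empty `other` leaves `this` unchanged, so the division never sees a zero denominator.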
@@ ... @@
 package spark

 import org.scalatest.FunSuite
 import scala.collection.mutable.ArrayBuffer
 import SparkContext._
+import spark.util.StatCounter
+import scala.math.abs

 class PartitioningSuite extends FunSuite with LocalSparkContext {

@@ -120,4 +120,21 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
     assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array"))
     assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array"))
   }
+
+  test("Zero-length partitions should be correctly handled") {
+    // Create an RDD with some consecutive empty partitions (including the "first" one)
+    sc = new SparkContext("local", "test")
+    val rdd: RDD[Double] = sc
+      .parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8)
+      .filter(_ >= 0.0)
+
+    // Run the partitions, including the consecutive empty ones, through StatCounter
+    val stats: StatCounter = rdd.stats()
+    assert(abs(6.0 - stats.sum) < 0.01)
+    assert(abs(6.0 / 2 - rdd.mean) < 0.01)
+    assert(abs(1.0 - rdd.variance) < 0.01)
+    assert(abs(1.0 - rdd.stdev) < 0.01)
+
+    // Add other tests here for classes that should be able to handle empty partitions correctly
+  }
 }
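As a sanity check on the expected values: the filter keeps only 2.0 and 4.0, so sum = 6.0, mean = 6.0 / 2 = 3.0, population variance = ((2 - 3)^2 + (4 - 3)^2) / 2 = 1.0, and stdev = sqrt(1.0) = 1.0. The same numbers can be reproduced without Spark by driving StatCounter directly, since its constructor takes a TraversableOnce[Double] (a sketch, assuming StatCounter exposes the same sum/mean/variance/stdev accessors that the RDD methods report):

import spark.util.StatCounter

object ExpectedStats extends App {
  // Feed the two surviving values straight into StatCounter, bypassing the RDD machinery.
  val s = new StatCounter(Seq(2.0, 4.0))
  println(s.sum)       // 6.0
  println(s.mean)      // 3.0
  println(s.variance)  // 1.0, the population variance
  println(s.stdev)     // 1.0
}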
@@ -16,3 +16,5 @@ addSbtPlugin("io.spray" %% "sbt-twirl" % "0.6.1")
 //resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns)

 //addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6")
+
+libraryDependencies += "com.novocode" % "junit-interface" % "0.10-M4" % "test"
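junit-interface is sbt's JUnit test-framework adapter, presumably added here so JUnit-style tests can run under sbt test alongside ScalaTest suites; it is pulled in with test scope only. A hypothetical JUnit-style check it would pick up (illustrative, not part of this commit):

import org.junit.Test
import org.junit.Assert.assertFalse
import spark.util.StatCounter

// Hypothetical test: with the merge fix above, combining two empty counters must not produce NaN.
class StatCounterJUnitTest {
  @Test
  def mergingTwoEmptyCountersYieldsNoNaN(): Unit = {
    val merged = new StatCounter(Nil).merge(new StatCounter(Nil))
    assertFalse("mean should not be NaN", merged.mean.isNaN)
  }
}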