From 7f63e85b47a93434030482160e88fe63bf9cff4e Mon Sep 17 00:00:00 2001
From: Shixiong Zhu <shixiong@databricks.com>
Date: Wed, 2 Aug 2017 10:59:59 -0700
Subject: [PATCH] [SPARK-21597][SS] Fix a potential overflow issue in
 EventTimeStats

## What changes were proposed in this pull request?

This PR fixed a potential overflow issue in `EventTimeStats`. The old implementation kept a running `Long` sum of event times and computed the average as `sum / count`, which can overflow for large timestamps and counts. The average is now maintained incrementally as a `Double`, so no intermediate value grows beyond the magnitude of a single timestamp.

## How was this patch tested?

The new unit tests.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #18803 from zsxwing/avg.
---
 .../streaming/EventTimeWatermarkExec.scala  | 10 ++---
 .../streaming/ProgressReporter.scala        |  2 +-
 .../streaming/EventTimeWatermarkSuite.scala | 41 ++++++++++++++++++-
 3 files changed, 44 insertions(+), 9 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
index 87e5b78550..b161651c4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
@@ -27,27 +27,25 @@ import org.apache.spark.unsafe.types.CalendarInterval
 import org.apache.spark.util.AccumulatorV2
 
 /** Class for collecting event time stats with an accumulator */
-case class EventTimeStats(var max: Long, var min: Long, var sum: Long, var count: Long) {
+case class EventTimeStats(var max: Long, var min: Long, var avg: Double, var count: Long) {
   def add(eventTime: Long): Unit = {
     this.max = math.max(this.max, eventTime)
     this.min = math.min(this.min, eventTime)
-    this.sum += eventTime
     this.count += 1
+    this.avg += (eventTime - avg) / count
   }
 
   def merge(that: EventTimeStats): Unit = {
     this.max = math.max(this.max, that.max)
     this.min = math.min(this.min, that.min)
-    this.sum += that.sum
     this.count += that.count
+    this.avg += (that.avg - this.avg) * that.count / this.count
   }
-
-  def avg: Long = sum / count
 }
 
 object EventTimeStats {
   def zero: EventTimeStats = EventTimeStats(
-    max = Long.MinValue, min = Long.MaxValue, sum = 0L, count = 0L)
+    max = Long.MinValue, min = Long.MaxValue, avg = 0.0, count = 0L)
 }
 
 /** Accumulator that collects stats on event time in a batch. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index 1887b07c49..c5fbb638e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -259,7 +259,7 @@ trait ProgressReporter extends Logging {
         Map(
           "max" -> stats.max,
           "min" -> stats.min,
-          "avg" -> stats.avg).mapValues(formatTimestamp)
+          "avg" -> stats.avg.toLong).mapValues(formatTimestamp)
     }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp
 
     ExecutionStats(numInputRows, stateOperators, eventTimeStats)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
index 1b60a06ec4..552911f32e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
@@ -21,7 +21,7 @@ import java.{util => ju}
 import java.text.SimpleDateFormat
 import java.util.Date
 
-import org.scalatest.BeforeAndAfter
+import org.scalatest.{BeforeAndAfter, Matchers}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.functions.{count, window}
 import org.apache.spark.sql.streaming.OutputMode._
 
-class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Logging {
+class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matchers with Logging {
 
   import testImplicits._
 
@@ -38,6 +38,43 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin
     sqlContext.streams.active.foreach(_.stop())
   }
 
+  test("EventTimeStats") {
+    val epsilon = 10E-6
+
+    val stats = EventTimeStats(max = 100, min = 10, avg = 20.0, count = 5)
+    stats.add(80L)
+    stats.max should be (100)
+    stats.min should be (10)
+    stats.avg should be (30.0 +- epsilon)
+    stats.count should be (6)
+
+    val stats2 = EventTimeStats(80L, 5L, 15.0, 4)
+    stats.merge(stats2)
+    stats.max should be (100)
+    stats.min should be (5)
+    stats.avg should be (24.0 +- epsilon)
+    stats.count should be (10)
+  }
+
+  test("EventTimeStats: avg on large values") {
+    val epsilon = 10E-6
+    val largeValue = 10000000000L // 10B
+    // Make sure `largeValue` will cause overflow if we use a Long sum to calc avg.
+    assert(largeValue * largeValue != BigInt(largeValue) * BigInt(largeValue))
+    val stats =
+      EventTimeStats(max = largeValue, min = largeValue, avg = largeValue, count = largeValue - 1)
+    stats.add(largeValue)
+    stats.avg should be (largeValue.toDouble +- epsilon)
+
+    val stats2 = EventTimeStats(
+      max = largeValue + 1,
+      min = largeValue,
+      avg = largeValue + 1,
+      count = largeValue)
+    stats.merge(stats2)
+    stats.avg should be ((largeValue + 0.5) +- epsilon)
+  }
+
   test("error on bad column") {
     val inputData = MemoryStream[Int].toDF()
     val e = intercept[AnalysisException] {
--
GitLab
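
For readers skimming the diff, here is a minimal standalone sketch (not part of the patch) that contrasts the removed `Long`-sum average with the incremental `Double` average this PR introduces. The names `EventTimeAvgDemo`, `SumAvg`, and `IncAvg` are illustrative only; the update and merge formulas mirror `EventTimeStats.add` and `EventTimeStats.merge` above.

```scala
// Minimal sketch, assuming nothing from Spark: contrasts the removed Long-sum
// average with the incremental Double average introduced by this patch.
object EventTimeAvgDemo {

  // Old approach: keep a running Long sum and divide on read.
  // The sum silently wraps once it passes Long.MaxValue (~9.22e18).
  final class SumAvg(var sum: Long, var count: Long) {
    def add(x: Long): Unit = { sum += x; count += 1 }
    def avg: Long = sum / count
  }

  // Patched approach: maintain the mean directly, so no intermediate value
  // ever grows beyond the magnitude of a single timestamp.
  final class IncAvg(var avg: Double, var count: Long) {
    def add(x: Long): Unit = { count += 1; avg += (x - avg) / count }
    // Count-weighted merge, same formula as EventTimeStats.merge in the diff.
    def merge(that: IncAvg): Unit = {
      count += that.count
      avg += (that.avg - avg) * that.count / count
    }
  }

  def main(args: Array[String]): Unit = {
    val largeValue = 10000000000L // 10B, the magnitude used by the new unit test

    // Seed the state directly instead of looping 10B times. A Long sum of
    // (largeValue - 1) timestamps of size largeValue is ~1e20, which has
    // already wrapped around Long.MaxValue by the time it is stored here.
    val bySum = new SumAvg(largeValue * (largeValue - 1), largeValue - 1)
    val byInc = new IncAvg(largeValue.toDouble, largeValue - 1)

    bySum.add(largeValue)
    byInc.add(largeValue)

    println(s"Long-sum avg:    ${bySum.avg}") // a wrapped value, nowhere near 1.0e10
    println(s"Incremental avg: ${byInc.avg}") // 1.0E10, as expected
  }
}
```

Seeding `sum` and `count` directly, as the new unit test does, exercises the overflow boundary without actually looping ten billion times: the `Long` sum has already wrapped before the first `add`, while the incremental form never leaves the magnitude of a single timestamp.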