From 358e7bf652d6fedd9377593025cd661c142efeca Mon Sep 17 00:00:00 2001 From: Xiangrui Meng <meng@databricks.com> Date: Thu, 16 Jul 2015 23:02:06 -0700 Subject: [PATCH] [SPARK-9126] [MLLIB] do not assert on time taken by Thread.sleep() Measure lower and upper bounds for task time and use them for validation. This PR also implements `Stopwatch.toString`. This suite should finish in less than 1 second. jkbradley pwendell Author: Xiangrui Meng <meng@databricks.com> Closes #7457 from mengxr/SPARK-9126 and squashes the following commits: 4b40faa [Xiangrui Meng] simplify tests 739f5bd [Xiangrui Meng] do not assert on time taken by Thread.sleep() --- .../apache/spark/ml/util/stopwatches.scala | 4 +- .../apache/spark/ml/util/StopwatchSuite.scala | 64 ++++++++++++------- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala index 5fdf878a3d..8d4174124b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala @@ -67,6 +67,8 @@ private[spark] abstract class Stopwatch extends Serializable { */ def elapsed(): Long + override def toString: String = s"$name: ${elapsed()}ms" + /** * Gets the current time in milliseconds. */ @@ -145,7 +147,7 @@ private[spark] class MultiStopwatch(@transient private val sc: SparkContext) ext override def toString: String = { stopwatches.values.toArray.sortBy(_.name) - .map(c => s" ${c.name}: ${c.elapsed()}ms") + .map(c => s" $c") .mkString("{\n", ",\n", "\n}") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala index 8df6617fe0..9e6bc7193c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -17,11 +17,15 @@ package org.apache.spark.ml.util +import java.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { + import StopwatchSuite._ + private def testStopwatchOnDriver(sw: Stopwatch): Unit = { assert(sw.name === "sw") assert(sw.elapsed() === 0L) @@ -29,18 +33,13 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { intercept[AssertionError] { sw.stop() } - sw.start() - Thread.sleep(50) - val duration = sw.stop() - assert(duration >= 50 && duration < 100) // using a loose upper bound + val duration = checkStopwatch(sw) val elapsed = sw.elapsed() assert(elapsed === duration) - sw.start() - Thread.sleep(50) - val duration2 = sw.stop() - assert(duration2 >= 50 && duration2 < 100) + val duration2 = checkStopwatch(sw) val elapsed2 = sw.elapsed() assert(elapsed2 === duration + duration2) + assert(sw.toString === s"sw: ${elapsed2}ms") sw.start() assert(sw.isRunning) intercept[AssertionError] { @@ -61,14 +60,13 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { test("DistributedStopwatch on executors") { val sw = new DistributedStopwatch(sc, "sw") val rdd = sc.parallelize(0 until 4, 4) + val acc = sc.accumulator(0L) rdd.foreach { i => - sw.start() - Thread.sleep(50) - sw.stop() + acc += checkStopwatch(sw) } assert(!sw.isRunning) val elapsed = sw.elapsed() - assert(elapsed >= 200 && elapsed < 400) // using a loose upper bound + assert(elapsed === acc.value) } test("MultiStopwatch") { @@ -81,29 +79,47 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { sw("some") } assert(sw.toString === "{\n local: 0ms,\n spark: 0ms\n}") - sw("local").start() - sw("spark").start() - Thread.sleep(50) - sw("local").stop() - Thread.sleep(50) - sw("spark").stop() + val localDuration = checkStopwatch(sw("local")) + val sparkDuration = checkStopwatch(sw("spark")) val localElapsed = sw("local").elapsed() val sparkElapsed = sw("spark").elapsed() - assert(localElapsed >= 50 && localElapsed < 100) - assert(sparkElapsed >= 100 && sparkElapsed < 200) + assert(localElapsed === localDuration) + assert(sparkElapsed === sparkDuration) assert(sw.toString === s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}") val rdd = sc.parallelize(0 until 4, 4) + val acc = sc.accumulator(0L) rdd.foreach { i => sw("local").start() - sw("spark").start() - Thread.sleep(50) - sw("spark").stop() + val duration = checkStopwatch(sw("spark")) sw("local").stop() + acc += duration } val localElapsed2 = sw("local").elapsed() assert(localElapsed2 === localElapsed) val sparkElapsed2 = sw("spark").elapsed() - assert(sparkElapsed2 >= 300 && sparkElapsed2 < 600) + assert(sparkElapsed2 === sparkElapsed + acc.value) } } + +private object StopwatchSuite extends SparkFunSuite { + + /** + * Checks the input stopwatch on a task that takes a random time (<10ms) to finish. Validates and + * returns the duration reported by the stopwatch. + */ + def checkStopwatch(sw: Stopwatch): Long = { + val ubStart = now + sw.start() + val lbStart = now + Thread.sleep(new Random().nextInt(10)) + val lb = now - lbStart + val duration = sw.stop() + val ub = now - ubStart + assert(duration >= lb && duration <= ub) + duration + } + + /** The current time in milliseconds. */ + private def now: Long = System.currentTimeMillis() +} -- GitLab