diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala index 5fdf878a3df72d7acdd8c45e8befc01db0abb25b..8d4174124b5c4f83b236b9517ab9b94145c7d9c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala @@ -67,6 +67,8 @@ private[spark] abstract class Stopwatch extends Serializable { */ def elapsed(): Long + override def toString: String = s"$name: ${elapsed()}ms" + /** * Gets the current time in milliseconds. */ @@ -145,7 +147,7 @@ private[spark] class MultiStopwatch(@transient private val sc: SparkContext) ext override def toString: String = { stopwatches.values.toArray.sortBy(_.name) - .map(c => s" ${c.name}: ${c.elapsed()}ms") + .map(c => s" $c") .mkString("{\n", ",\n", "\n}") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala index 8df6617fe0228366e5082bbce17c43079429b212..9e6bc7193c13bd35140b5f530162d89e419cee8d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -17,11 +17,15 @@ package org.apache.spark.ml.util +import java.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { + import StopwatchSuite._ + private def testStopwatchOnDriver(sw: Stopwatch): Unit = { assert(sw.name === "sw") assert(sw.elapsed() === 0L) @@ -29,18 +33,13 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { intercept[AssertionError] { sw.stop() } - sw.start() - Thread.sleep(50) - val duration = sw.stop() - assert(duration >= 50 && duration < 100) // using a loose upper bound + val duration = checkStopwatch(sw) val elapsed = sw.elapsed() assert(elapsed === duration) - sw.start() - Thread.sleep(50) - val duration2 = sw.stop() - assert(duration2 >= 50 && duration2 < 100) + val duration2 = checkStopwatch(sw) val elapsed2 = sw.elapsed() assert(elapsed2 === duration + duration2) + assert(sw.toString === s"sw: ${elapsed2}ms") sw.start() assert(sw.isRunning) intercept[AssertionError] { @@ -61,14 +60,13 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { test("DistributedStopwatch on executors") { val sw = new DistributedStopwatch(sc, "sw") val rdd = sc.parallelize(0 until 4, 4) + val acc = sc.accumulator(0L) rdd.foreach { i => - sw.start() - Thread.sleep(50) - sw.stop() + acc += checkStopwatch(sw) } assert(!sw.isRunning) val elapsed = sw.elapsed() - assert(elapsed >= 200 && elapsed < 400) // using a loose upper bound + assert(elapsed === acc.value) } test("MultiStopwatch") { @@ -81,29 +79,47 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { sw("some") } assert(sw.toString === "{\n local: 0ms,\n spark: 0ms\n}") - sw("local").start() - sw("spark").start() - Thread.sleep(50) - sw("local").stop() - Thread.sleep(50) - sw("spark").stop() + val localDuration = checkStopwatch(sw("local")) + val sparkDuration = checkStopwatch(sw("spark")) val localElapsed = sw("local").elapsed() val sparkElapsed = sw("spark").elapsed() - assert(localElapsed >= 50 && localElapsed < 100) - assert(sparkElapsed >= 100 && sparkElapsed < 200) + assert(localElapsed === localDuration) + assert(sparkElapsed === sparkDuration) assert(sw.toString === s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}") val rdd = sc.parallelize(0 until 4, 4) + val acc = sc.accumulator(0L) rdd.foreach { i => sw("local").start() - sw("spark").start() - Thread.sleep(50) - sw("spark").stop() + val duration = checkStopwatch(sw("spark")) sw("local").stop() + acc += duration } val localElapsed2 = sw("local").elapsed() assert(localElapsed2 === localElapsed) val sparkElapsed2 = sw("spark").elapsed() - assert(sparkElapsed2 >= 300 && sparkElapsed2 < 600) + assert(sparkElapsed2 === sparkElapsed + acc.value) } } + +private object StopwatchSuite extends SparkFunSuite { + + /** + * Checks the input stopwatch on a task that takes a random time (<10ms) to finish. Validates and + * returns the duration reported by the stopwatch. + */ + def checkStopwatch(sw: Stopwatch): Long = { + val ubStart = now + sw.start() + val lbStart = now + Thread.sleep(new Random().nextInt(10)) + val lb = now - lbStart + val duration = sw.stop() + val ub = now - ubStart + assert(duration >= lb && duration <= ub) + duration + } + + /** The current time in milliseconds. */ + private def now: Long = System.currentTimeMillis() +}