diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 4d49fe51598508a559cc15baed42eebf36d798ca..8acd0439b690046489020c540b359f72870d8c82 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -34,6 +34,14 @@ import org.apache.spark.serializer.JavaSerializer class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { import AccumulatorParam._ + override def afterEach(): Unit = { + try { + Accumulators.clear() + } finally { + super.afterEach() + } + } + implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = new AccumulableParam[mutable.Set[A], A] { def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]) : mutable.Set[A] = { diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index c426bb7a4e80999b060ea80b4b548ec2b4fd9d14..474550608ba2f21b55ffd0105fa2402f12b65c88 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -28,6 +28,14 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { import InternalAccumulator._ import AccumulatorParam._ + override def afterEach(): Unit = { + try { + Accumulators.clear() + } finally { + super.afterEach() + } + } + test("get param") { assert(getParam(EXECUTOR_DESERIALIZE_TIME) === LongAccumulatorParam) assert(getParam(EXECUTOR_RUN_TIME) === LongAccumulatorParam) diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index d3359c7406e45ca07d2d70890543416cd8ba8af0..99366a32c4e163856415184d1691ef8451897876 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -18,14 +18,26 @@ package org.apache.spark // scalastyle:off -import org.scalatest.{FunSuite, Outcome} +import org.scalatest.{BeforeAndAfterAll, FunSuite, Outcome} /** * Base abstract class for all unit tests in Spark for handling common functionality. */ -private[spark] abstract class SparkFunSuite extends FunSuite with Logging { +private[spark] abstract class SparkFunSuite + extends FunSuite + with BeforeAndAfterAll + with Logging { // scalastyle:on + protected override def afterAll(): Unit = { + try { + // Avoid leaking map entries in tests that use accumulators without SparkContext + Accumulators.clear() + } finally { + super.afterAll() + } + } + /** * Log the suite name and the test name before and after each test. * @@ -42,8 +54,6 @@ private[spark] abstract class SparkFunSuite extends FunSuite with Logging { test() } finally { logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n") - // Avoid leaking map entries in tests that use accumulators without SparkContext - Accumulators.clear() } }