diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 5f31bfba3f8d62a039d60c226f00ff192935d617..30f0ccd73ccca6a5013fc8462e24225b861dac67 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -23,6 +23,7 @@ import java.lang.ThreadLocal import scala.collection.generic.Growable import scala.collection.mutable.Map +import scala.ref.WeakReference import scala.reflect.ClassTag import org.apache.spark.serializer.JavaSerializer @@ -280,10 +281,12 @@ object AccumulatorParam { // TODO: The multi-thread support in accumulators is kind of lame; check // if there's a more intuitive way of doing it right private[spark] object Accumulators { - // TODO: Use soft references? => need to make readObject work properly then - val originals = Map[Long, Accumulable[_, _]]() - val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() { - override protected def initialValue() = Map[Long, Accumulable[_, _]]() + // Store a WeakReference instead of a StrongReference because this way accumulators can be + // appropriately garbage collected during long-running jobs and release memory + type WeakAcc = WeakReference[Accumulable[_, _]] + val originals = Map[Long, WeakAcc]() + val localAccums = new ThreadLocal[Map[Long, WeakAcc]]() { + override protected def initialValue() = Map[Long, WeakAcc]() } var lastId: Long = 0 @@ -294,9 +297,9 @@ private[spark] object Accumulators { def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized { if (original) { - originals(a.id) = a + originals(a.id) = new WeakAcc(a) } else { - localAccums.get()(a.id) = a + localAccums.get()(a.id) = new WeakAcc(a) } } @@ -307,11 +310,22 @@ private[spark] object Accumulators { } } + def remove(accId: Long) { + synchronized { + originals.remove(accId) + } + } + // Get the values of the local accumulators for the current thread (by ID) def values: Map[Long, Any] = synchronized { val ret = Map[Long, Any]() for ((id, accum) <- localAccums.get) { - ret(id) = accum.localValue + // Since we are now storing weak references, we must check whether the underlying data + // is valid. + ret(id) = accum.get match { + case Some(values) => values.localValue + case None => None + } } return ret } @@ -320,7 +334,13 @@ private[spark] object Accumulators { def add(values: Map[Long, Any]): Unit = synchronized { for ((id, value) <- values) { if (originals.contains(id)) { - originals(id).asInstanceOf[Accumulable[Any, Any]] ++= value + // Since we are now storing weak references, we must check whether the underlying data + // is valid. + originals(id).get match { + case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] ++= value + case None => + throw new IllegalAccessError("Attempted to access garbage collected Accumulator.") + } } } } diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index ede1e23f4fcc5c8a3658720dde35394c65c12c1c..434f1e47cf822c9bc45c2c34d0a6a14b01d3fc2b 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -32,6 +32,7 @@ private sealed trait CleanupTask private case class CleanRDD(rddId: Int) extends CleanupTask private case class CleanShuffle(shuffleId: Int) extends CleanupTask private case class CleanBroadcast(broadcastId: Long) extends CleanupTask +private case class CleanAccum(accId: Long) extends CleanupTask /** * A WeakReference associated with a CleanupTask. @@ -114,6 +115,10 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { registerForCleanup(rdd, CleanRDD(rdd.id)) } + def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = { + registerForCleanup(a, CleanAccum(a.id)) + } + /** Register a ShuffleDependency for cleanup when it is garbage collected. */ def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) { registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId)) @@ -145,6 +150,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks) case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) + case CleanAccum(accId) => + doCleanupAccum(accId, blocking = blockOnCleanupTasks) } } } catch { @@ -190,6 +197,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } + /** Perform accumulator cleanup. */ + def doCleanupAccum(accId: Long, blocking: Boolean) { + try { + logDebug("Cleaning accumulator " + accId) + Accumulators.remove(accId) + listeners.foreach(_.accumCleaned(accId)) + logInfo("Cleaned accumulator " + accId) + } catch { + case e: Exception => logError("Error cleaning accumulator " + accId, e) + } + } + private def blockManagerMaster = sc.env.blockManager.master private def broadcastManager = sc.env.broadcastManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] @@ -206,4 +225,5 @@ private[spark] trait CleanerListener { def rddCleaned(rddId: Int) def shuffleCleaned(shuffleId: Int) def broadcastCleaned(broadcastId: Long) + def accumCleaned(accId: Long) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 85ec5ea11357eb112a59d631de146c76d23e006e..930d4bea4785bc21aedc8338c8fee810dee761b2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -986,7 +986,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * values to using the `+=` method. Only the driver can access the accumulator's `value`. */ def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = - new Accumulator(initialValue, param) + { + val acc = new Accumulator(initialValue, param) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } /** * Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display @@ -994,7 +998,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * driver can access the accumulator's `value`. */ def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) = { - new Accumulator(initialValue, param, Some(name)) + val acc = new Accumulator(initialValue, param, Some(name)) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc } /** @@ -1003,8 +1009,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ - def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) = - new Accumulable(initialValue, param) + def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) = { + val acc = new Accumulable(initialValue, param) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } /** * Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the @@ -1013,8 +1022,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ - def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) = - new Accumulable(initialValue, param, Some(name)) + def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) = { + val acc = new Accumulable(initialValue, param, Some(name)) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } /** * Create an accumulator from a "mutable collection" type. @@ -1025,7 +1037,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] (initialValue: R): Accumulable[R, T] = { val param = new GrowableAccumulableParam[R,T] - new Accumulable(initialValue, param) + val acc = new Accumulable(initialValue, param) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c58721c2c82b7f9c4d6e51d2d27bf02a5fd234ed..bc84e2351ad74e8bf9bd90cc87d71893ca311f6d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -890,8 +890,16 @@ class DAGScheduler( if (event.accumUpdates != null) { try { Accumulators.add(event.accumUpdates) + event.accumUpdates.foreach { case (id, partialValue) => - val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]] + // In this instance, although the reference in Accumulators.originals is a WeakRef, + // it's guaranteed to exist since the event.accumUpdates Map exists + + val acc = Accumulators.originals(id).get match { + case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] + case None => throw new NullPointerException("Non-existent reference to Accumulator") + } + // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && partialValue != acc.zero) { val name = acc.name.get diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index f087fc550dde3f3338f252494e301ef15dd4c202..bd0f8bdefa171e3de0149877ed668684a719b337 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import scala.collection.mutable +import scala.ref.WeakReference import org.scalatest.FunSuite import org.scalatest.Matchers @@ -136,4 +137,23 @@ class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { } } + test ("garbage collection") { + // Create an accumulator and let it go out of scope to test that it's properly garbage collected + sc = new SparkContext("local", "test") + var acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) + val accId = acc.id + val ref = WeakReference(acc) + + // Ensure the accumulator is present + assert(ref.get.isDefined) + + // Remove the explicit reference to it and allow weak reference to get garbage collected + acc = null + System.gc() + assert(ref.get.isEmpty) + + Accumulators.remove(accId) + assert(!Accumulators.originals.get(accId).isDefined) + } + } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index ae2ae7ed0d3aaad9bc94038fc0937c4e1f7399ee..cdfaacee7da409b2a87a56cf8bf91411b6efdf87 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -382,6 +382,10 @@ class CleanerTester( toBeCleanedBroadcstIds -= broadcastId logInfo("Broadcast" + broadcastId + " cleaned") } + + def accumCleaned(accId: Long): Unit = { + logInfo("Cleaned accId " + accId + " cleaned") + } } val MAX_VALIDATION_ATTEMPTS = 10 diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 9d0c1273695f66d0e69b1d34bb35cd724919ef85..4bf7f9e647d555fc93b9183f2181ff4488e28bee 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -735,7 +735,11 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42))) completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assert(Accumulators.originals(accum.id).value === 1) + + val accVal = Accumulators.originals(accum.id).get.get.value + + assert(accVal === 1) + assertDataStructuresEmpty }