From 95cd643aa954b7e4229e94fa8bdc99bf3b2bb1da Mon Sep 17 00:00:00 2001 From: Ilya Ganelin <ilya.ganelin@capitalone.com> Date: Sun, 22 Feb 2015 22:43:04 -0800 Subject: [PATCH] [SPARK-3885] Provide mechanism to remove accumulators once they are no longer used Instead of storing a strong reference to accumulators, I've replaced this with a weak reference and updated any code that uses these accumulators to check whether the reference resolves before using the accumulator. A weak reference will be cleared when there is no longer an existing copy of the variable versus using a soft reference in which case accumulators would only be cleared when the GC explicitly ran out of memory. Author: Ilya Ganelin <ilya.ganelin@capitalone.com> Closes #4021 from ilganeli/SPARK-3885 and squashes the following commits: 4ba9575 [Ilya Ganelin] Fixed error in test suite 8510943 [Ilya Ganelin] Extra code bb76ef0 [Ilya Ganelin] File deleted somehow 283a333 [Ilya Ganelin] Added cleanup method for accumulators to remove stale references within Accumulators.original to accumulators that are now out of scope 345fd4f [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-3885 7485a82 [Ilya Ganelin] Fixed build error c8e0f2b [Ilya Ganelin] Added working test for accumulator garbage collection 94ce754 [Ilya Ganelin] Still not being properly garbage collected 8722b63 [Ilya Ganelin] Fixing gc test 7414a9c [Ilya Ganelin] Added test for accumulator garbage collection 18d62ec [Ilya Ganelin] Updated to throw Exception when accessing a GCd accumulator 9a81928 [Ilya Ganelin] Reverting permissions changes 28f705c [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-3885 b820ab4b [Ilya Ganelin] reset d78f4bf [Ilya Ganelin] Removed obsolete comment 0746e61 [Ilya Ganelin] Updated DAGSchedulerSUite to fix bug 3350852 [Ilya Ganelin] Updated DAGScheduler and Suite to correctly use new implementation of WeakRef Accumulator storage c49066a [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-3885 cbb9023 [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-3885 a77d11b [Ilya Ganelin] Updated Accumulators class to store weak references instead of strong references to allow garbage collection of old accumulators --- .../scala/org/apache/spark/Accumulators.scala | 36 ++++++++++++++----- .../org/apache/spark/ContextCleaner.scala | 20 +++++++++++ .../scala/org/apache/spark/SparkContext.scala | 28 +++++++++++---- .../apache/spark/scheduler/DAGScheduler.scala | 10 +++++- .../org/apache/spark/AccumulatorSuite.scala | 20 +++++++++++ .../apache/spark/ContextCleanerSuite.scala | 4 +++ .../spark/scheduler/DAGSchedulerSuite.scala | 6 +++- 7 files changed, 107 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 5f31bfba3f..30f0ccd73c 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -23,6 +23,7 @@ import java.lang.ThreadLocal import scala.collection.generic.Growable import scala.collection.mutable.Map +import scala.ref.WeakReference import scala.reflect.ClassTag import org.apache.spark.serializer.JavaSerializer @@ -280,10 +281,12 @@ object AccumulatorParam { // TODO: The multi-thread support in accumulators is kind of lame; check // if there's a more intuitive way of doing it right private[spark] object Accumulators { - // TODO: Use soft references? => need to make readObject work properly then - val originals = Map[Long, Accumulable[_, _]]() - val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() { - override protected def initialValue() = Map[Long, Accumulable[_, _]]() + // Store a WeakReference instead of a StrongReference because this way accumulators can be + // appropriately garbage collected during long-running jobs and release memory + type WeakAcc = WeakReference[Accumulable[_, _]] + val originals = Map[Long, WeakAcc]() + val localAccums = new ThreadLocal[Map[Long, WeakAcc]]() { + override protected def initialValue() = Map[Long, WeakAcc]() } var lastId: Long = 0 @@ -294,9 +297,9 @@ private[spark] object Accumulators { def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized { if (original) { - originals(a.id) = a + originals(a.id) = new WeakAcc(a) } else { - localAccums.get()(a.id) = a + localAccums.get()(a.id) = new WeakAcc(a) } } @@ -307,11 +310,22 @@ private[spark] object Accumulators { } } + def remove(accId: Long) { + synchronized { + originals.remove(accId) + } + } + // Get the values of the local accumulators for the current thread (by ID) def values: Map[Long, Any] = synchronized { val ret = Map[Long, Any]() for ((id, accum) <- localAccums.get) { - ret(id) = accum.localValue + // Since we are now storing weak references, we must check whether the underlying data + // is valid. + ret(id) = accum.get match { + case Some(values) => values.localValue + case None => None + } } return ret } @@ -320,7 +334,13 @@ private[spark] object Accumulators { def add(values: Map[Long, Any]): Unit = synchronized { for ((id, value) <- values) { if (originals.contains(id)) { - originals(id).asInstanceOf[Accumulable[Any, Any]] ++= value + // Since we are now storing weak references, we must check whether the underlying data + // is valid. + originals(id).get match { + case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] ++= value + case None => + throw new IllegalAccessError("Attempted to access garbage collected Accumulator.") + } } } } diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index ede1e23f4f..434f1e47cf 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -32,6 +32,7 @@ private sealed trait CleanupTask private case class CleanRDD(rddId: Int) extends CleanupTask private case class CleanShuffle(shuffleId: Int) extends CleanupTask private case class CleanBroadcast(broadcastId: Long) extends CleanupTask +private case class CleanAccum(accId: Long) extends CleanupTask /** * A WeakReference associated with a CleanupTask. @@ -114,6 +115,10 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { registerForCleanup(rdd, CleanRDD(rdd.id)) } + def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = { + registerForCleanup(a, CleanAccum(a.id)) + } + /** Register a ShuffleDependency for cleanup when it is garbage collected. */ def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) { registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId)) @@ -145,6 +150,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks) case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) + case CleanAccum(accId) => + doCleanupAccum(accId, blocking = blockOnCleanupTasks) } } } catch { @@ -190,6 +197,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } + /** Perform accumulator cleanup. */ + def doCleanupAccum(accId: Long, blocking: Boolean) { + try { + logDebug("Cleaning accumulator " + accId) + Accumulators.remove(accId) + listeners.foreach(_.accumCleaned(accId)) + logInfo("Cleaned accumulator " + accId) + } catch { + case e: Exception => logError("Error cleaning accumulator " + accId, e) + } + } + private def blockManagerMaster = sc.env.blockManager.master private def broadcastManager = sc.env.broadcastManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] @@ -206,4 +225,5 @@ private[spark] trait CleanerListener { def rddCleaned(rddId: Int) def shuffleCleaned(shuffleId: Int) def broadcastCleaned(broadcastId: Long) + def accumCleaned(accId: Long) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 85ec5ea113..930d4bea47 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -986,7 +986,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * values to using the `+=` method. Only the driver can access the accumulator's `value`. */ def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = - new Accumulator(initialValue, param) + { + val acc = new Accumulator(initialValue, param) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } /** * Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display @@ -994,7 +998,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * driver can access the accumulator's `value`. */ def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) = { - new Accumulator(initialValue, param, Some(name)) + val acc = new Accumulator(initialValue, param, Some(name)) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc } /** @@ -1003,8 +1009,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ - def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) = - new Accumulable(initialValue, param) + def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) = { + val acc = new Accumulable(initialValue, param) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } /** * Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the @@ -1013,8 +1022,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ - def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) = - new Accumulable(initialValue, param, Some(name)) + def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) = { + val acc = new Accumulable(initialValue, param, Some(name)) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } /** * Create an accumulator from a "mutable collection" type. @@ -1025,7 +1037,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] (initialValue: R): Accumulable[R, T] = { val param = new GrowableAccumulableParam[R,T] - new Accumulable(initialValue, param) + val acc = new Accumulable(initialValue, param) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c58721c2c8..bc84e2351a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -890,8 +890,16 @@ class DAGScheduler( if (event.accumUpdates != null) { try { Accumulators.add(event.accumUpdates) + event.accumUpdates.foreach { case (id, partialValue) => - val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]] + // In this instance, although the reference in Accumulators.originals is a WeakRef, + // it's guaranteed to exist since the event.accumUpdates Map exists + + val acc = Accumulators.originals(id).get match { + case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] + case None => throw new NullPointerException("Non-existent reference to Accumulator") + } + // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && partialValue != acc.zero) { val name = acc.name.get diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index f087fc550d..bd0f8bdefa 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import scala.collection.mutable +import scala.ref.WeakReference import org.scalatest.FunSuite import org.scalatest.Matchers @@ -136,4 +137,23 @@ class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { } } + test ("garbage collection") { + // Create an accumulator and let it go out of scope to test that it's properly garbage collected + sc = new SparkContext("local", "test") + var acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) + val accId = acc.id + val ref = WeakReference(acc) + + // Ensure the accumulator is present + assert(ref.get.isDefined) + + // Remove the explicit reference to it and allow weak reference to get garbage collected + acc = null + System.gc() + assert(ref.get.isEmpty) + + Accumulators.remove(accId) + assert(!Accumulators.originals.get(accId).isDefined) + } + } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index ae2ae7ed0d..cdfaacee7d 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -382,6 +382,10 @@ class CleanerTester( toBeCleanedBroadcstIds -= broadcastId logInfo("Broadcast" + broadcastId + " cleaned") } + + def accumCleaned(accId: Long): Unit = { + logInfo("Cleaned accId " + accId + " cleaned") + } } val MAX_VALIDATION_ATTEMPTS = 10 diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 9d0c127369..4bf7f9e647 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -735,7 +735,11 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42))) completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assert(Accumulators.originals(accum.id).value === 1) + + val accVal = Accumulators.originals(accum.id).get.get.value + + assert(accVal === 1) + assertDataStructuresEmpty } -- GitLab