Skip to content
Snippets Groups Projects
Commit 6f9a18fe authored by Wenchen Fan's avatar Wenchen Fan Committed by Reynold Xin
Browse files

[HOTFIX][CORE] fix a concurrency issue in NewAccumulator

## What changes were proposed in this pull request?

`AccumulatorContext` is not thread-safe, which is why all of its methods are synchronized. However, there is one exception: `AccumulatorContext.originals`. `NewAccumulator` uses it to check whether it is registered, which is wrong because that access is not synchronized.

This PR marks `AccumulatorContext.originals` as `private`, so that all access to `AccumulatorContext` is now synchronized.

## How was this patch tested?

I verified it locally. To be safe, we can let jenkins test it many times to make sure this problem is gone.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12773 from cloud-fan/debug.
parent 9c7c42bc
No related branches found
No related tags found
No related merge requests found
...@@ -22,6 +22,8 @@ import java.io.ObjectInputStream ...@@ -22,6 +22,8 @@ import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.atomic.AtomicLong
import javax.annotation.concurrent.GuardedBy import javax.annotation.concurrent.GuardedBy
import scala.collection.JavaConverters._
import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.util.Utils import org.apache.spark.util.Utils
...@@ -57,7 +59,7 @@ abstract class NewAccumulator[IN, OUT] extends Serializable { ...@@ -57,7 +59,7 @@ abstract class NewAccumulator[IN, OUT] extends Serializable {
* registered before use, or it will throw an exception. * registered before use, or it will throw an exception.
*/ */
final def isRegistered: Boolean = final def isRegistered: Boolean =
metadata != null && AccumulatorContext.originals.containsKey(metadata.id) metadata != null && AccumulatorContext.get(metadata.id).isDefined
private def assertMetadataNotNull(): Unit = { private def assertMetadataNotNull(): Unit = {
if (metadata == null) { if (metadata == null) {
...@@ -197,7 +199,7 @@ private[spark] object AccumulatorContext { ...@@ -197,7 +199,7 @@ private[spark] object AccumulatorContext {
* TODO: Don't use a global map; these should be tied to a SparkContext (SPARK-13051). * TODO: Don't use a global map; these should be tied to a SparkContext (SPARK-13051).
*/ */
@GuardedBy("AccumulatorContext") @GuardedBy("AccumulatorContext")
val originals = new java.util.HashMap[Long, jl.ref.WeakReference[NewAccumulator[_, _]]] private val originals = new java.util.HashMap[Long, jl.ref.WeakReference[NewAccumulator[_, _]]]
private[this] val nextId = new AtomicLong(0L) private[this] val nextId = new AtomicLong(0L)
...@@ -207,6 +209,10 @@ private[spark] object AccumulatorContext { ...@@ -207,6 +209,10 @@ private[spark] object AccumulatorContext {
*/ */
def newId(): Long = nextId.getAndIncrement def newId(): Long = nextId.getAndIncrement
def numAccums: Int = synchronized(originals.size)
def accumIds: Set[Long] = synchronized(originals.keySet().asScala.toSet)
/** /**
* Register an [[Accumulator]] created on the driver such that it can be used on the executors. * Register an [[Accumulator]] created on the driver such that it can be used on the executors.
* *
......
...@@ -191,7 +191,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex ...@@ -191,7 +191,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
assert(ref.get.isEmpty) assert(ref.get.isEmpty)
AccumulatorContext.remove(accId) AccumulatorContext.remove(accId)
assert(!AccumulatorContext.originals.containsKey(accId)) assert(!AccumulatorContext.get(accId).isDefined)
} }
test("get accum") { test("get accum") {
......
...@@ -183,18 +183,18 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { ...@@ -183,18 +183,18 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
private val myCleaner = new SaveAccumContextCleaner(this) private val myCleaner = new SaveAccumContextCleaner(this)
override def cleaner: Option[ContextCleaner] = Some(myCleaner) override def cleaner: Option[ContextCleaner] = Some(myCleaner)
} }
assert(AccumulatorContext.originals.isEmpty) assert(AccumulatorContext.numAccums == 0)
sc.parallelize(1 to 100).map { i => (i, i) }.reduceByKey { _ + _ }.count() sc.parallelize(1 to 100).map { i => (i, i) }.reduceByKey { _ + _ }.count()
val numInternalAccums = TaskMetrics.empty.internalAccums.length val numInternalAccums = TaskMetrics.empty.internalAccums.length
// We ran 2 stages, so we should have 2 sets of internal accumulators, 1 for each stage // We ran 2 stages, so we should have 2 sets of internal accumulators, 1 for each stage
assert(AccumulatorContext.originals.size === numInternalAccums * 2) assert(AccumulatorContext.numAccums === numInternalAccums * 2)
val accumsRegistered = sc.cleaner match { val accumsRegistered = sc.cleaner match {
case Some(cleaner: SaveAccumContextCleaner) => cleaner.accumsRegisteredForCleanup case Some(cleaner: SaveAccumContextCleaner) => cleaner.accumsRegisteredForCleanup
case _ => Seq.empty[Long] case _ => Seq.empty[Long]
} }
// Make sure the same set of accumulators is registered for cleanup // Make sure the same set of accumulators is registered for cleanup
assert(accumsRegistered.size === numInternalAccums * 2) assert(accumsRegistered.size === numInternalAccums * 2)
assert(accumsRegistered.toSet === AccumulatorContext.originals.keySet().asScala) assert(accumsRegistered.toSet === AccumulatorContext.accumIds)
} }
/** /**
......
...@@ -334,10 +334,10 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext ...@@ -334,10 +334,10 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
sql("SELECT * FROM t2").count() sql("SELECT * FROM t2").count()
AccumulatorContext.synchronized { AccumulatorContext.synchronized {
val accsSize = AccumulatorContext.originals.size val accsSize = AccumulatorContext.numAccums
sqlContext.uncacheTable("t1") sqlContext.uncacheTable("t1")
sqlContext.uncacheTable("t2") sqlContext.uncacheTable("t2")
assert((accsSize - 2) == AccumulatorContext.originals.size) assert((accsSize - 2) == AccumulatorContext.numAccums)
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment