From b90bf520fd7b979a90d1377cfc2ee7f0bf82c705 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Thu, 27 Apr 2017 19:38:14 -0700
Subject: [PATCH] [SPARK-12837][CORE] Do not send the name of internal
 accumulator to executor side

## What changes were proposed in this pull request?

When sending accumulator updates back to the driver, the network overhead is significant because there are many accumulators: `TaskMetrics` sends about 20 accumulators every time, and there may be many `SQLMetric`s if the query plan is complicated.

Therefore, it's important to reduce the size of the serialized accumulators. A simple way is to not send the names of internal accumulators to the executor side, as they are unnecessary there. When the executor sends accumulator updates back to the driver, we can easily look up the accumulator name in `AccumulatorContext`. Note that we still need to send the names of normal accumulators, as user code running on the executor side may rely on them.
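
For illustration, here is a minimal, self-contained Scala sketch of the idea. The `SimpleMetadata`/`SimpleAccum` classes and the registry map are hypothetical stand-ins, not Spark's actual API: the copy shipped to executors drops the name of an internal accumulator, and the driver recovers the name by id from a registry, which is the role `AccumulatorContext` plays in this patch.

```scala
import scala.collection.mutable

// Hypothetical stand-ins for AccumulatorMetadata / AccumulatorV2.
case class SimpleMetadata(id: Long, name: Option[String])

class SimpleAccum(var metadata: SimpleMetadata) {
  // Driver side: resolve the name by id from the registry, so the copy sent to
  // executors never needs to carry it.
  def name(registry: mutable.Map[Long, SimpleAccum], atDriverSide: Boolean): Option[String] =
    if (atDriverSide) registry.get(metadata.id).flatMap(_.metadata.name)
    else metadata.name

  // Called before shipping the accumulator to executors: internal accumulators
  // drop their name to shrink the serialized payload.
  def copyForExecutor(isInternal: Boolean): SimpleAccum = {
    val meta = if (isInternal) metadata.copy(name = None) else metadata
    new SimpleAccum(meta)
  }
}

object NameStrippingDemo extends App {
  val registry = mutable.Map.empty[Long, SimpleAccum] // plays the role of AccumulatorContext
  val driverAcc = new SimpleAccum(SimpleMetadata(id = 1L, name = Some("internal.metrics.resultSize")))
  registry(driverAcc.metadata.id) = driverAcc

  val executorCopy = driverAcc.copyForExecutor(isInternal = true)
  assert(executorCopy.metadata.name.isEmpty) // name not serialized to executors
  // Back on the driver, the name is still recoverable via the registry (by id).
  assert(driverAcc.name(registry, atDriverSide = true).contains("internal.metrics.resultSize"))
  println("name stripped for executor, recovered on driver")
}
```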

In the future, we should reimplement `TaskMetrics` so that it does not rely on accumulators and uses custom serialization instead.

Tested on the example in https://issues.apache.org/jira/browse/SPARK-12837, the size of the serialized accumulators is reduced by about 40%.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #17596 from cloud-fan/oom.
---
 .../apache/spark/executor/TaskMetrics.scala   | 29 ++++++-------
 .../org/apache/spark/scheduler/Task.scala     | 13 +++---
 .../org/apache/spark/util/AccumulatorV2.scala | 28 +++++++------
 .../spark/scheduler/TaskContextSuite.scala    |  2 +-
 .../ui/jobs/JobProgressListenerSuite.scala    |  2 +-
 .../apache/spark/util/JsonProtocolSuite.scala |  2 +-
 .../SpecificParquetRecordReaderBase.java      | 12 +++---
 .../parquet/ParquetFilterSuite.scala          | 42 +++++++++++++++----
 8 files changed, 76 insertions(+), 54 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index dfd2f818ac..a3ce3d1ccc 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -251,13 +251,10 @@ class TaskMetrics private[spark] () extends Serializable {
 
   private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = internalAccums ++ externalAccums
 
-  /**
-   * Looks for a registered accumulator by accumulator name.
-   */
-  private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = {
-    accumulators.find { acc =>
-      acc.name.isDefined && acc.name.get == name
-    }
+  private[spark] def nonZeroInternalAccums(): Seq[AccumulatorV2[_, _]] = {
+    // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its
+    // value will be updated at driver side.
+    internalAccums.filter(a => !a.isZero || a == _resultSize)
   }
 }
 
@@ -308,16 +305,16 @@ private[spark] object TaskMetrics extends Logging {
    */
   def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = {
     val tm = new TaskMetrics
-    val (internalAccums, externalAccums) =
-      accums.partition(a => a.name.isDefined && tm.nameToAccums.contains(a.name.get))
-
-    internalAccums.foreach { acc =>
-      val tmAcc = tm.nameToAccums(acc.name.get).asInstanceOf[AccumulatorV2[Any, Any]]
-      tmAcc.metadata = acc.metadata
-      tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]])
+    for (acc <- accums) {
+      val name = acc.name
+      if (name.isDefined && tm.nameToAccums.contains(name.get)) {
+        val tmAcc = tm.nameToAccums(name.get).asInstanceOf[AccumulatorV2[Any, Any]]
+        tmAcc.metadata = acc.metadata
+        tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]])
+      } else {
+        tm.externalAccums += acc
+      }
     }
-
-    tm.externalAccums ++= externalAccums
     tm
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 7fd2918960..5c337b992c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -182,14 +182,11 @@ private[spark] abstract class Task[T](
    */
   def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulatorV2[_, _]] = {
     if (context != null) {
-      context.taskMetrics.internalAccums.filter { a =>
-        // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its
-        // value will be updated at driver side.
-        // Note: internal accumulators representing task metrics always count failed values
-        !a.isZero || a.name == Some(InternalAccumulator.RESULT_SIZE)
-      // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not filter
-      // them out.
-      } ++ context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues)
+      // Note: internal accumulators representing task metrics always count failed values
+      context.taskMetrics.nonZeroInternalAccums() ++
+        // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not
+        // filter them out.
+        context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues)
     } else {
       Seq.empty
     }
diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
index 7479de5514..a65ec75cc5 100644
--- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -84,8 +84,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
    * Returns the name of this accumulator, can only be called after registration.
    */
   final def name: Option[String] = {
-    assertMetadataNotNull()
-    metadata.name
+    if (atDriverSide) {
+      AccumulatorContext.get(id).flatMap(_.metadata.name)
+    } else {
+      assertMetadataNotNull()
+      metadata.name
+    }
   }
 
   /**
@@ -161,7 +165,15 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
       }
       val copyAcc = copyAndReset()
       assert(copyAcc.isZero, "copyAndReset must return a zero value copy")
-      copyAcc.metadata = metadata
+      val isInternalAcc =
+        (name.isDefined && name.get.startsWith(InternalAccumulator.METRICS_PREFIX)) ||
+          getClass.getSimpleName == "SQLMetric"
+      if (isInternalAcc) {
+        // Do not serialize the name of internal accumulator and send it to executor.
+        copyAcc.metadata = metadata.copy(name = None)
+      } else {
+        copyAcc.metadata = metadata
+      }
       copyAcc
     } else {
       this
@@ -263,16 +275,6 @@ private[spark] object AccumulatorContext {
     originals.clear()
   }
 
-  /**
-   * Looks for a registered accumulator by accumulator name.
-   */
-  private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = {
-    originals.values().asScala.find { ref =>
-      val acc = ref.get
-      acc != null && acc.name.isDefined && acc.name.get == name
-    }.map(_.get)
-  }
-
   // Identifier for distinguishing SQL metrics from other accumulators
   private[spark] val SQL_ACCUM_IDENTIFIER = "sql"
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 8f576daa77..b22da565d8 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -198,7 +198,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     sc = new SparkContext("local", "test")
     // Create a dummy task. We won't end up running this; we just want to collect
     // accumulator updates from it.
-    val taskMetrics = TaskMetrics.empty
+    val taskMetrics = TaskMetrics.registered
     val task = new Task[Int](0, 0, 0) {
       context = new TaskContextImpl(0, 0, 0L, 0,
         new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 93964a2d56..48be3be817 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -293,7 +293,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
     val execId = "exe-1"
 
     def makeTaskMetrics(base: Int): TaskMetrics = {
-      val taskMetrics = TaskMetrics.empty
+      val taskMetrics = TaskMetrics.registered
       val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics()
       val shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics
       val inputMetrics = taskMetrics.inputMetrics
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index a64dbeae47..a77c8e3cab 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -830,7 +830,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
       hasHadoopInput: Boolean,
       hasOutput: Boolean,
       hasRecords: Boolean = true) = {
-    val t = TaskMetrics.empty
+    val t = TaskMetrics.registered
     // Set CPU times same as wall times for testing purpose
     t.setExecutorDeserializeTime(a)
     t.setExecutorDeserializeCpuTime(a)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
index eb97118872..0bab321a65 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
@@ -153,14 +153,14 @@ public abstract class SpecificParquetRecordReaderBase<T> extends RecordReader<Vo
     }
 
     // For test purpose.
-    // If the predefined accumulator exists, the row group number to read will be updated
-    // to the accumulator. So we can check if the row groups are filtered or not in test case.
+    // If the last external accumulator is `NumRowGroupsAccumulator`, the row group number to read
+    // will be updated to the accumulator. So we can check if the row groups are filtered or not
+    // in test case.
     TaskContext taskContext = TaskContext$.MODULE$.get();
     if (taskContext != null) {
-      Option<AccumulatorV2<?, ?>> accu = taskContext.taskMetrics()
-        .lookForAccumulatorByName("numRowGroups");
-      if (accu.isDefined()) {
-        ((LongAccumulator)accu.get()).add((long)blocks.size());
+      Option<AccumulatorV2<?, ?>> accu = taskContext.taskMetrics().externalAccums().lastOption();
+      if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) {
+        ((AccumulatorV2<Integer, Integer>)accu.get()).add(blocks.size());
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 9a3328fcec..dd53b56132 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
-import org.apache.spark.util.{AccumulatorContext, LongAccumulator}
+import org.apache.spark.util.{AccumulatorContext, AccumulatorV2}
 
 /**
  * A test suite that tests Parquet filter2 API based filter pushdown optimization.
@@ -499,18 +499,20 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
         val path = s"${dir.getCanonicalPath}/table"
         (1 to 1024).map(i => (101, i)).toDF("a", "b").write.parquet(path)
 
-        Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0)).map { case (push, func) =>
-          withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) {
-            val accu = new LongAccumulator
-            accu.register(sparkContext, Some("numRowGroups"))
+        Seq(true, false).foreach { enablePushDown =>
+          withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> enablePushDown.toString) {
+            val accu = new NumRowGroupsAcc
+            sparkContext.register(accu)
 
             val df = spark.read.parquet(path).filter("a < 100")
             df.foreachPartition(_.foreach(v => accu.add(0)))
             df.collect
 
-            val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups")
-            assert(numRowGroups.isDefined)
-            assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value))
+            if (enablePushDown) {
+              assert(accu.value == 0)
+            } else {
+              assert(accu.value > 0)
+            }
             AccumulatorContext.remove(accu.id)
           }
         }
@@ -537,3 +539,27 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
     }
   }
 }
+
+class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {
+  private var _sum = 0
+
+  override def isZero: Boolean = _sum == 0
+
+  override def copy(): AccumulatorV2[Integer, Integer] = {
+    val acc = new NumRowGroupsAcc()
+    acc._sum = _sum
+    acc
+  }
+
+  override def reset(): Unit = _sum = 0
+
+  override def add(v: Integer): Unit = _sum += v
+
+  override def merge(other: AccumulatorV2[Integer, Integer]): Unit = other match {
+    case a: NumRowGroupsAcc => _sum += a._sum
+    case _ => throw new UnsupportedOperationException(
+      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+  }
+
+  override def value: Integer = _sum
+}
-- 
GitLab