From ebfd91c542aaead343cb154277fcf9114382fee7 Mon Sep 17 00:00:00 2001
From: zsxwing <zsxwing@gmail.com>
Date: Fri, 7 Aug 2015 00:09:58 -0700
Subject: [PATCH] [SPARK-9467][SQL]Add SQLMetric to specialize accumulators to
 avoid boxing

This PR adds SQLMetric/SQLMetricParam/SQLMetricValue to specialize accumulators to avoid boxing. All SQL metrics should use these classes rather than `Accumulator`.

Author: zsxwing <zsxwing@gmail.com>

Closes #7996 from zsxwing/sql-accu and squashes the following commits:

14a5f0a [zsxwing] Address comments
367ca23 [zsxwing] Use localValue directly to avoid changing Accumulable
42f50c3 [zsxwing] Add SQLMetric to specialize accumulators to avoid boxing
---
 .../scala/org/apache/spark/Accumulators.scala |   2 +-
 .../scala/org/apache/spark/SparkContext.scala |  15 --
 .../spark/sql/execution/SparkPlan.scala       |  33 ++--
 .../spark/sql/execution/basicOperators.scala  |  11 +-
 .../apache/spark/sql/metric/SQLMetrics.scala  | 149 ++++++++++++++++++
 .../org/apache/spark/sql/ui/SQLListener.scala |  17 +-
 .../apache/spark/sql/ui/SparkPlanGraph.scala  |   8 +-
 .../spark/sql/metric/SQLMetricsSuite.scala    | 145 +++++++++++++++++
 .../spark/sql/ui/SQLListenerSuite.scala       |   5 +-
 9 files changed, 338 insertions(+), 47 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala

diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 462d5c96d4..064246dfa7 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -257,7 +257,7 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa
  */
 class Accumulator[T] private[spark] (
     @transient private[spark] val initialValue: T,
-    private[spark] val param: AccumulatorParam[T],
+    param: AccumulatorParam[T],
     name: Option[String],
     internal: Boolean)
   extends Accumulable[T, T](initialValue, param, name, internal) {
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 5662686436..9ced44131b 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1238,21 +1238,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     acc
   }
 
-  /**
-   * Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display
-   * in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the
-   * driver can access the accumulator's `value`. The latest local value of such accumulator will be
-   * sent back to the driver via heartbeats.
-   *
-   * @tparam T type that can be added to the accumulator, must be thread safe
-   */
-  private[spark] def internalAccumulator[T](initialValue: T, name: String)(
-    implicit param: AccumulatorParam[T]): Accumulator[T] = {
-    val acc = new Accumulator(initialValue, param, Some(name), internal = true)
-    cleaner.foreach(_.registerAccumulatorForCleanup(acc))
-    acc
-  }
-
   /**
    * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values
    * with `+=`. Only the driver can access the accumuable's `value`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 719ad432e2..1915496d16 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{Accumulator, Logging}
+import org.apache.spark.Logging
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.{RDD, RDDOperationScope}
 import org.apache.spark.sql.SQLContext
@@ -32,6 +32,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.metric.{IntSQLMetric, LongSQLMetric, SQLMetric, SQLMetrics}
 import org.apache.spark.sql.types.DataType
 
 object SparkPlan {
@@ -84,22 +85,30 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    */
   protected[sql] def trackNumOfRowsEnabled: Boolean = false
 
-  private lazy val numOfRowsAccumulator = sparkContext.internalAccumulator(0L, "number of rows")
+  private lazy val defaultMetrics: Map[String, SQLMetric[_, _]] =
+    if (trackNumOfRowsEnabled) {
+      Map("numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows"))
+    }
+    else {
+      Map.empty
+    }
 
   /**
-   * Return all accumulators containing metrics of this SparkPlan.
+   * Return all metrics containing metrics of this SparkPlan.
    */
-  private[sql] def accumulators: Map[String, Accumulator[_]] = if (trackNumOfRowsEnabled) {
-      Map("numRows" -> numOfRowsAccumulator)
-    } else {
-      Map.empty
-    }
+  private[sql] def metrics: Map[String, SQLMetric[_, _]] = defaultMetrics
+
+  /**
+   * Return a IntSQLMetric according to the name.
+   */
+  private[sql] def intMetric(name: String): IntSQLMetric =
+    metrics(name).asInstanceOf[IntSQLMetric]
 
   /**
-   * Return the accumulator according to the name.
+   * Return a LongSQLMetric according to the name.
    */
-  private[sql] def accumulator[T](name: String): Accumulator[T] =
-    accumulators(name).asInstanceOf[Accumulator[T]]
+  private[sql] def longMetric(name: String): LongSQLMetric =
+    metrics(name).asInstanceOf[LongSQLMetric]
 
   // TODO: Move to `DistributedPlan`
   /** Specifies how data is partitioned across different nodes in the cluster. */
@@ -148,7 +157,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
       prepare()
       if (trackNumOfRowsEnabled) {
-        val numRows = accumulator[Long]("numRows")
+        val numRows = longMetric("numRows")
         doExecute().map { row =>
           numRows += 1
           row
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index f4677b4ee8..0680f31d40 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.metric.SQLMetrics
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.collection.ExternalSorter
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
@@ -81,13 +82,13 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan)
 case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 
-  private[sql] override lazy val accumulators = Map(
-    "numInputRows" -> sparkContext.internalAccumulator(0L, "number of input rows"),
-    "numOutputRows" -> sparkContext.internalAccumulator(0L, "number of output rows"))
+  private[sql] override lazy val metrics = Map(
+    "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
+    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
 
   protected override def doExecute(): RDD[InternalRow] = {
-    val numInputRows = accumulator[Long]("numInputRows")
-    val numOutputRows = accumulator[Long]("numOutputRows")
+    val numInputRows = longMetric("numInputRows")
+    val numOutputRows = longMetric("numOutputRows")
     child.execute().mapPartitions { iter =>
       val predicate = newPredicate(condition, child.output)
       iter.filter { row =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala
new file mode 100644
index 0000000000..3b907e5da7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala
@@ -0,0 +1,149 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.metric
+
+import org.apache.spark.{Accumulable, AccumulableParam, SparkContext}
+
+/**
+ * Create a layer for specialized metric. We cannot add `@specialized` to
+ * `Accumulable/AccumulableParam` because it will break Java source compatibility.
+ *
+ * An implementation of SQLMetric should override `+=` and `add` to avoid boxing.
+ */
+private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
+    name: String, val param: SQLMetricParam[R, T])
+  extends Accumulable[R, T](param.zero, param, Some(name), true)
+
+/**
+ * Create a layer for specialized metric. We cannot add `@specialized` to
+ * `Accumulable/AccumulableParam` because it will break Java source compatibility.
+ */
+private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] {
+
+  def zero: R
+}
+
+/**
+ * Create a layer for specialized metric. We cannot add `@specialized` to
+ * `Accumulable/AccumulableParam` because it will break Java source compatibility.
+ */
+private[sql] trait SQLMetricValue[T] extends Serializable {
+
+  def value: T
+
+  override def toString: String = value.toString
+}
+
+/**
+ * A wrapper of Long to avoid boxing and unboxing when using Accumulator
+ */
+private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] {
+
+  def add(incr: Long): LongSQLMetricValue = {
+    _value += incr
+    this
+  }
+
+  // Although there is a boxing here, it's fine because it's only called in SQLListener
+  override def value: Long = _value
+}
+
+/**
+ * A wrapper of Int to avoid boxing and unboxing when using Accumulator
+ */
+private[sql] class IntSQLMetricValue(private var _value: Int) extends SQLMetricValue[Int] {
+
+  def add(term: Int): IntSQLMetricValue = {
+    _value += term
+    this
+  }
+
+  // Although there is a boxing here, it's fine because it's only called in SQLListener
+  override def value: Int = _value
+}
+
+/**
+ * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's
+ * `+=` and `add`.
+ */
+private[sql] class LongSQLMetric private[metric](name: String)
+  extends SQLMetric[LongSQLMetricValue, Long](name, LongSQLMetricParam) {
+
+  override def +=(term: Long): Unit = {
+    localValue.add(term)
+  }
+
+  override def add(term: Long): Unit = {
+    localValue.add(term)
+  }
+}
+
+/**
+ * A specialized int Accumulable to avoid boxing and unboxing when using Accumulator's
+ * `+=` and `add`.
+ */
+private[sql] class IntSQLMetric private[metric](name: String)
+  extends SQLMetric[IntSQLMetricValue, Int](name, IntSQLMetricParam) {
+
+  override def +=(term: Int): Unit = {
+    localValue.add(term)
+  }
+
+  override def add(term: Int): Unit = {
+    localValue.add(term)
+  }
+}
+
+private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Long] {
+
+  override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t)
+
+  override def addInPlace(r1: LongSQLMetricValue, r2: LongSQLMetricValue): LongSQLMetricValue =
+    r1.add(r2.value)
+
+  override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero
+
+  override def zero: LongSQLMetricValue = new LongSQLMetricValue(0L)
+}
+
+private object IntSQLMetricParam extends SQLMetricParam[IntSQLMetricValue, Int] {
+
+  override def addAccumulator(r: IntSQLMetricValue, t: Int): IntSQLMetricValue = r.add(t)
+
+  override def addInPlace(r1: IntSQLMetricValue, r2: IntSQLMetricValue): IntSQLMetricValue =
+    r1.add(r2.value)
+
+  override def zero(initialValue: IntSQLMetricValue): IntSQLMetricValue = zero
+
+  override def zero: IntSQLMetricValue = new IntSQLMetricValue(0)
+}
+
+private[sql] object SQLMetrics {
+
+  def createIntMetric(sc: SparkContext, name: String): IntSQLMetric = {
+    val acc = new IntSQLMetric(name)
+    sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+    acc
+  }
+
+  def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = {
+    val acc = new LongSQLMetric(name)
+    sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+    acc
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala
index e7b1dd1ffa..2fd4fc658d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala
@@ -21,11 +21,12 @@ import scala.collection.mutable
 
 import com.google.common.annotations.VisibleForTesting
 
-import org.apache.spark.{AccumulatorParam, JobExecutionStatus, Logging}
+import org.apache.spark.{JobExecutionStatus, Logging}
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.scheduler._
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue}
 
 private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener with Logging {
 
@@ -36,8 +37,6 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit
 
   // Old data in the following fields must be removed in "trimExecutionsIfNecessary".
   // If adding new fields, make sure "trimExecutionsIfNecessary" can clean up old data
-
-  // VisibleForTesting
   private val _executionIdToData = mutable.HashMap[Long, SQLExecutionUIData]()
 
   /**
@@ -270,9 +269,10 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit
                accumulatorUpdate <- taskMetrics.accumulatorUpdates.toSeq) yield {
             accumulatorUpdate
           }
-        }.filter { case (id, _) => executionUIData.accumulatorMetrics.keySet(id) }
+        }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) }
         mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId =>
-          executionUIData.accumulatorMetrics(accumulatorId).accumulatorParam)
+          executionUIData.accumulatorMetrics(accumulatorId).metricParam).
+          mapValues(_.asInstanceOf[SQLMetricValue[_]].value)
       case None =>
         // This execution has been dropped
         Map.empty
@@ -281,10 +281,11 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit
 
   private def mergeAccumulatorUpdates(
       accumulatorUpdates: Seq[(Long, Any)],
-      paramFunc: Long => AccumulatorParam[Any]): Map[Long, Any] = {
+      paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, Any] = {
     accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) =>
       val param = paramFunc(accumulatorId)
-      (accumulatorId, values.map(_._2).reduceLeft(param.addInPlace))
+      (accumulatorId,
+        values.map(_._2.asInstanceOf[SQLMetricValue[Any]]).foldLeft(param.zero)(param.addInPlace))
     }
   }
 
@@ -336,7 +337,7 @@ private[ui] class SQLExecutionUIData(
 private[ui] case class SQLPlanMetric(
     name: String,
     accumulatorId: Long,
-    accumulatorParam: AccumulatorParam[Any])
+    metricParam: SQLMetricParam[SQLMetricValue[Any], Any])
 
 /**
  * Store all accumulatorUpdates for all tasks in a Spark stage.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala
index 7910c163ba..1ba50b95be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala
@@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong
 
 import scala.collection.mutable
 
-import org.apache.spark.AccumulatorParam
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue}
 
 /**
  * A graph used for storing information of an executionPlan of DataFrame.
@@ -61,9 +61,9 @@ private[sql] object SparkPlanGraph {
       nodeIdGenerator: AtomicLong,
       nodes: mutable.ArrayBuffer[SparkPlanGraphNode],
       edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = {
-    val metrics = plan.accumulators.toSeq.map { case (key, accumulator) =>
-      SQLPlanMetric(accumulator.name.getOrElse(key), accumulator.id,
-        accumulator.param.asInstanceOf[AccumulatorParam[Any]])
+    val metrics = plan.metrics.toSeq.map { case (key, metric) =>
+      SQLPlanMetric(metric.name.getOrElse(key), metric.id,
+        metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]])
     }
     val node = SparkPlanGraphNode(
       nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala
new file mode 100644
index 0000000000..d22160f538
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala
@@ -0,0 +1,145 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.metric
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
+import scala.collection.mutable
+
+import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._
+import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.util.Utils
+
+
+class SQLMetricsSuite extends SparkFunSuite {
+
+  test("LongSQLMetric should not box Long") {
+    val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long")
+    val f = () => { l += 1L }
+    BoxingFinder.getClassReader(f.getClass).foreach { cl =>
+      val boxingFinder = new BoxingFinder()
+      cl.accept(boxingFinder, 0)
+      assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}")
+    }
+  }
+
+  test("IntSQLMetric should not box Int") {
+    val l = SQLMetrics.createIntMetric(TestSQLContext.sparkContext, "Int")
+    val f = () => { l += 1 }
+    BoxingFinder.getClassReader(f.getClass).foreach { cl =>
+      val boxingFinder = new BoxingFinder()
+      cl.accept(boxingFinder, 0)
+      assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}")
+    }
+  }
+
+  test("Normal accumulator should do boxing") {
+    // We need this test to make sure BoxingFinder works.
+    val l = TestSQLContext.sparkContext.accumulator(0L)
+    val f = () => { l += 1L }
+    BoxingFinder.getClassReader(f.getClass).foreach { cl =>
+      val boxingFinder = new BoxingFinder()
+      cl.accept(boxingFinder, 0)
+      assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test")
+    }
+  }
+}
+
+private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String)
+
+/**
+ * If `method` is null, search all methods of this class recursively to find if they do some boxing.
+ * If `method` is specified, only search this method of the class to speed up the searching.
+ *
+ * This method will skip the methods in `visitedMethods` to avoid potential infinite cycles.
+ */
+private class BoxingFinder(
+    method: MethodIdentifier[_] = null,
+    val boxingInvokes: mutable.Set[String] = mutable.Set.empty,
+    visitedMethods: mutable.Set[MethodIdentifier[_]] = mutable.Set.empty)
+  extends ClassVisitor(ASM4) {
+
+  private val primitiveBoxingClassName =
+    Set("java/lang/Long",
+      "java/lang/Double",
+      "java/lang/Integer",
+      "java/lang/Float",
+      "java/lang/Short",
+      "java/lang/Character",
+      "java/lang/Byte",
+      "java/lang/Boolean")
+
+  override def visitMethod(
+      access: Int, name: String, desc: String, sig: String, exceptions: Array[String]):
+    MethodVisitor = {
+    if (method != null && (method.name != name || method.desc != desc)) {
+      // If method is specified, skip other methods.
+      return new MethodVisitor(ASM4) {}
+    }
+
+    new MethodVisitor(ASM4) {
+      override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) {
+        if (op == INVOKESPECIAL && name == "<init>" || op == INVOKESTATIC && name == "valueOf") {
+          if (primitiveBoxingClassName.contains(owner)) {
+            // Find boxing methods, e.g, new java.lang.Long(l) or java.lang.Long.valueOf(l)
+            boxingInvokes.add(s"$owner.$name")
+          }
+        } else {
+          // scalastyle:off classforname
+          val classOfMethodOwner = Class.forName(owner.replace('/', '.'), false,
+            Thread.currentThread.getContextClassLoader)
+          // scalastyle:on classforname
+          val m = MethodIdentifier(classOfMethodOwner, name, desc)
+          if (!visitedMethods.contains(m)) {
+            // Keep track of visited methods to avoid potential infinite cycles
+            visitedMethods += m
+            BoxingFinder.getClassReader(classOfMethodOwner).foreach { cl =>
+              visitedMethods += m
+              cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0)
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+private object BoxingFinder {
+
+  def getClassReader(cls: Class[_]): Option[ClassReader] = {
+    val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
+    val resourceStream = cls.getResourceAsStream(className)
+    val baos = new ByteArrayOutputStream(128)
+    // Copy data over, before delegating to ClassReader -
+    // else we can run out of open file handles.
+    Utils.copyStream(resourceStream, baos, true)
+    // ASM4 doesn't support Java 8 classes, which requires ASM5.
+    // So if the class is ASM5 (E.g., java.lang.Long when using JDK8 runtime to run these codes),
+    // then ClassReader will throw IllegalArgumentException,
+    // However, since this is only for testing, it's safe to skip these classes.
+    try {
+      Some(new ClassReader(new ByteArrayInputStream(baos.toByteArray)))
+    } catch {
+      case _: IllegalArgumentException => None
+    }
+  }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala
index f1fcaf5953..69a561e16a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala
@@ -21,6 +21,7 @@ import java.util.Properties
 
 import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite}
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.sql.metric.LongSQLMetricValue
 import org.apache.spark.scheduler._
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.execution.SQLExecution
@@ -65,9 +66,9 @@ class SQLListenerSuite extends SparkFunSuite {
     speculative = false
   )
 
-  private def createTaskMetrics(accumulatorUpdates: Map[Long, Any]): TaskMetrics = {
+  private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = {
     val metrics = new TaskMetrics
-    metrics.setAccumulatorsUpdater(() => accumulatorUpdates)
+    metrics.setAccumulatorsUpdater(() => accumulatorUpdates.mapValues(new LongSQLMetricValue(_)))
     metrics.updateAccumulators()
     metrics
   }
-- 
GitLab