From ebfd91c542aaead343cb154277fcf9114382fee7 Mon Sep 17 00:00:00 2001 From: zsxwing <zsxwing@gmail.com> Date: Fri, 7 Aug 2015 00:09:58 -0700 Subject: [PATCH] [SPARK-9467][SQL]Add SQLMetric to specialize accumulators to avoid boxing This PR adds SQLMetric/SQLMetricParam/SQLMetricValue to specialize accumulators to avoid boxing. All SQL metrics should use these classes rather than `Accumulator`. Author: zsxwing <zsxwing@gmail.com> Closes #7996 from zsxwing/sql-accu and squashes the following commits: 14a5f0a [zsxwing] Address comments 367ca23 [zsxwing] Use localValue directly to avoid changing Accumulable 42f50c3 [zsxwing] Add SQLMetric to specialize accumulators to avoid boxing --- .../scala/org/apache/spark/Accumulators.scala | 2 +- .../scala/org/apache/spark/SparkContext.scala | 15 -- .../spark/sql/execution/SparkPlan.scala | 33 ++-- .../spark/sql/execution/basicOperators.scala | 11 +- .../apache/spark/sql/metric/SQLMetrics.scala | 149 ++++++++++++++++++ .../org/apache/spark/sql/ui/SQLListener.scala | 17 +- .../apache/spark/sql/ui/SparkPlanGraph.scala | 8 +- .../spark/sql/metric/SQLMetricsSuite.scala | 145 +++++++++++++++++ .../spark/sql/ui/SQLListenerSuite.scala | 5 +- 9 files changed, 338 insertions(+), 47 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 462d5c96d4..064246dfa7 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -257,7 +257,7 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa */ class Accumulator[T] private[spark] ( @transient private[spark] val initialValue: T, - private[spark] val param: AccumulatorParam[T], + param: AccumulatorParam[T], name: Option[String], internal: Boolean) extends Accumulable[T, T](initialValue, param, name, internal) { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 5662686436..9ced44131b 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1238,21 +1238,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli acc } - /** - * Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display - * in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the - * driver can access the accumulator's `value`. The latest local value of such accumulator will be - * sent back to the driver via heartbeats. - * - * @tparam T type that can be added to the accumulator, must be thread safe - */ - private[spark] def internalAccumulator[T](initialValue: T, name: String)( - implicit param: AccumulatorParam[T]): Accumulator[T] = { - val acc = new Accumulator(initialValue, param, Some(name), internal = true) - cleaner.foreach(_.registerAccumulatorForCleanup(acc)) - acc - } - /** * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values * with `+=`. Only the driver can access the accumuable's `value`. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 719ad432e2..1915496d16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Accumulator, Logging} +import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.SQLContext @@ -32,6 +32,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.metric.{IntSQLMetric, LongSQLMetric, SQLMetric, SQLMetrics} import org.apache.spark.sql.types.DataType object SparkPlan { @@ -84,22 +85,30 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ */ protected[sql] def trackNumOfRowsEnabled: Boolean = false - private lazy val numOfRowsAccumulator = sparkContext.internalAccumulator(0L, "number of rows") + private lazy val defaultMetrics: Map[String, SQLMetric[_, _]] = + if (trackNumOfRowsEnabled) { + Map("numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) + } + else { + Map.empty + } /** - * Return all accumulators containing metrics of this SparkPlan. + * Return all metrics of this SparkPlan. */ - private[sql] def accumulators: Map[String, Accumulator[_]] = if (trackNumOfRowsEnabled) { - Map("numRows" -> numOfRowsAccumulator) - } else { - Map.empty - } + private[sql] def metrics: Map[String, SQLMetric[_, _]] = defaultMetrics + + /** + * Return the IntSQLMetric with the given name. + */ + private[sql] def intMetric(name: String): IntSQLMetric = + metrics(name).asInstanceOf[IntSQLMetric] /** - * Return the accumulator according to the name. + * Return the LongSQLMetric with the given name. */ - private[sql] def accumulator[T](name: String): Accumulator[T] = - accumulators(name).asInstanceOf[Accumulator[T]] + private[sql] def longMetric(name: String): LongSQLMetric = + metrics(name).asInstanceOf[LongSQLMetric] // TODO: Move to `DistributedPlan` /** Specifies how data is partitioned across different nodes in the cluster. 
*/ @@ -148,7 +157,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() if (trackNumOfRowsEnabled) { - val numRows = accumulator[Long]("numRows") + val numRows = longMetric("numRows") doExecute().map { row => numRows += 1 row diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index f4677b4ee8..0680f31d40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.metric.SQLMetrics import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.util.collection.unsafe.sort.PrefixComparator @@ -81,13 +82,13 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - private[sql] override lazy val accumulators = Map( - "numInputRows" -> sparkContext.internalAccumulator(0L, "number of input rows"), - "numOutputRows" -> sparkContext.internalAccumulator(0L, "number of output rows")) + private[sql] override lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) protected override def doExecute(): RDD[InternalRow] = { - val numInputRows = accumulator[Long]("numInputRows") - val numOutputRows = accumulator[Long]("numOutputRows") + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitions { iter => val predicate = newPredicate(condition, child.output) iter.filter { row => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala new file mode 100644 index 0000000000..3b907e5da7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala @@ -0,0 +1,149 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.metric + +import org.apache.spark.{Accumulable, AccumulableParam, SparkContext} + +/** + * Create a layer for specialized metric. 
We cannot add `@specialized` to + * `Accumulable/AccumulableParam` because it will break Java source compatibility. + * + * An implementation of SQLMetric should override `+=` and `add` to avoid boxing. + */ +private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T]( + name: String, val param: SQLMetricParam[R, T]) + extends Accumulable[R, T](param.zero, param, Some(name), true) + +/** + * Create a layer for specialized metric. We cannot add `@specialized` to + * `Accumulable/AccumulableParam` because it will break Java source compatibility. + */ +private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] { + + def zero: R +} + +/** + * Create a layer for specialized metric. We cannot add `@specialized` to + * `Accumulable/AccumulableParam` because it will break Java source compatibility. + */ +private[sql] trait SQLMetricValue[T] extends Serializable { + + def value: T + + override def toString: String = value.toString +} + +/** + * A wrapper of Long to avoid boxing and unboxing when using Accumulator + */ +private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] { + + def add(incr: Long): LongSQLMetricValue = { + _value += incr + this + } + + // Although there is a boxing here, it's fine because it's only called in SQLListener + override def value: Long = _value +} + +/** + * A wrapper of Int to avoid boxing and unboxing when using Accumulator + */ +private[sql] class IntSQLMetricValue(private var _value: Int) extends SQLMetricValue[Int] { + + def add(term: Int): IntSQLMetricValue = { + _value += term + this + } + + // Although there is a boxing here, it's fine because it's only called in SQLListener + override def value: Int = _value +} + +/** + * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's + * `+=` and `add`. + */ +private[sql] class LongSQLMetric private[metric](name: String) + extends SQLMetric[LongSQLMetricValue, Long](name, LongSQLMetricParam) { + + override def +=(term: Long): Unit = { + localValue.add(term) + } + + override def add(term: Long): Unit = { + localValue.add(term) + } +} + +/** + * A specialized int Accumulable to avoid boxing and unboxing when using Accumulator's + * `+=` and `add`. 
+ */ +private[sql] class IntSQLMetric private[metric](name: String) + extends SQLMetric[IntSQLMetricValue, Int](name, IntSQLMetricParam) { + + override def +=(term: Int): Unit = { + localValue.add(term) + } + + override def add(term: Int): Unit = { + localValue.add(term) + } +} + +private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Long] { + + override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t) + + override def addInPlace(r1: LongSQLMetricValue, r2: LongSQLMetricValue): LongSQLMetricValue = + r1.add(r2.value) + + override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero + + override def zero: LongSQLMetricValue = new LongSQLMetricValue(0L) +} + +private object IntSQLMetricParam extends SQLMetricParam[IntSQLMetricValue, Int] { + + override def addAccumulator(r: IntSQLMetricValue, t: Int): IntSQLMetricValue = r.add(t) + + override def addInPlace(r1: IntSQLMetricValue, r2: IntSQLMetricValue): IntSQLMetricValue = + r1.add(r2.value) + + override def zero(initialValue: IntSQLMetricValue): IntSQLMetricValue = zero + + override def zero: IntSQLMetricValue = new IntSQLMetricValue(0) +} + +private[sql] object SQLMetrics { + + def createIntMetric(sc: SparkContext, name: String): IntSQLMetric = { + val acc = new IntSQLMetric(name) + sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } + + def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { + val acc = new LongSQLMetric(name) + sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala index e7b1dd1ffa..2fd4fc658d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala @@ -21,11 +21,12 @@ import scala.collection.mutable import com.google.common.annotations.VisibleForTesting -import org.apache.spark.{AccumulatorParam, JobExecutionStatus, Logging} +import org.apache.spark.{JobExecutionStatus, Logging} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue} private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener with Logging { @@ -36,8 +37,6 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit // Old data in the following fields must be removed in "trimExecutionsIfNecessary". // If adding new fields, make sure "trimExecutionsIfNecessary" can clean up old data - - // VisibleForTesting private val _executionIdToData = mutable.HashMap[Long, SQLExecutionUIData]() /** @@ -270,9 +269,10 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit accumulatorUpdate <- taskMetrics.accumulatorUpdates.toSeq) yield { accumulatorUpdate } - }.filter { case (id, _) => executionUIData.accumulatorMetrics.keySet(id) } + }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) } mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId => - executionUIData.accumulatorMetrics(accumulatorId).accumulatorParam) + executionUIData.accumulatorMetrics(accumulatorId).metricParam). 
+ mapValues(_.asInstanceOf[SQLMetricValue[_]].value) case None => // This execution has been dropped Map.empty @@ -281,10 +281,11 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit private def mergeAccumulatorUpdates( accumulatorUpdates: Seq[(Long, Any)], - paramFunc: Long => AccumulatorParam[Any]): Map[Long, Any] = { + paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, Any] = { accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) => val param = paramFunc(accumulatorId) - (accumulatorId, values.map(_._2).reduceLeft(param.addInPlace)) + (accumulatorId, + values.map(_._2.asInstanceOf[SQLMetricValue[Any]]).foldLeft(param.zero)(param.addInPlace)) } } @@ -336,7 +337,7 @@ private[ui] class SQLExecutionUIData( private[ui] case class SQLPlanMetric( name: String, accumulatorId: Long, - accumulatorParam: AccumulatorParam[Any]) + metricParam: SQLMetricParam[SQLMetricValue[Any], Any]) /** * Store all accumulatorUpdates for all tasks in a Spark stage. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala index 7910c163ba..1ba50b95be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala @@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.spark.AccumulatorParam import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue} /** * A graph used for storing information of an executionPlan of DataFrame. @@ -61,9 +61,9 @@ private[sql] object SparkPlanGraph { nodeIdGenerator: AtomicLong, nodes: mutable.ArrayBuffer[SparkPlanGraphNode], edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { - val metrics = plan.accumulators.toSeq.map { case (key, accumulator) => - SQLPlanMetric(accumulator.name.getOrElse(key), accumulator.id, - accumulator.param.asInstanceOf[AccumulatorParam[Any]]) + val metrics = plan.metrics.toSeq.map { case (key, metric) => + SQLPlanMetric(metric.name.getOrElse(key), metric.id, + metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]) } val node = SparkPlanGraphNode( nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala new file mode 100644 index 0000000000..d22160f538 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala @@ -0,0 +1,145 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.metric + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import scala.collection.mutable + +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.util.Utils + + +class SQLMetricsSuite extends SparkFunSuite { + + test("LongSQLMetric should not box Long") { + val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long") + val f = () => { l += 1L } + BoxingFinder.getClassReader(f.getClass).foreach { cl => + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") + } + } + + test("IntSQLMetric should not box Int") { + val l = SQLMetrics.createIntMetric(TestSQLContext.sparkContext, "Int") + val f = () => { l += 1 } + BoxingFinder.getClassReader(f.getClass).foreach { cl => + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") + } + } + + test("Normal accumulator should do boxing") { + // We need this test to make sure BoxingFinder works. + val l = TestSQLContext.sparkContext.accumulator(0L) + val f = () => { l += 1L } + BoxingFinder.getClassReader(f.getClass).foreach { cl => + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.nonEmpty, "Expected to find boxing in this test") + } + } +} + +private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) + +/** + * If `method` is null, search all methods of this class recursively to find whether they do any boxing. + * If `method` is specified, only search that method of the class to speed up the search. + * + * The search skips the methods in `visitedMethods` to avoid potential infinite cycles. + */ +private class BoxingFinder( + method: MethodIdentifier[_] = null, + val boxingInvokes: mutable.Set[String] = mutable.Set.empty, + visitedMethods: mutable.Set[MethodIdentifier[_]] = mutable.Set.empty) + extends ClassVisitor(ASM4) { + + private val primitiveBoxingClassName = + Set("java/lang/Long", + "java/lang/Double", + "java/lang/Integer", + "java/lang/Float", + "java/lang/Short", + "java/lang/Character", + "java/lang/Byte", + "java/lang/Boolean") + + override def visitMethod( + access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): + MethodVisitor = { + if (method != null && (method.name != name || method.desc != desc)) { + // If method is specified, skip other methods. 
+ return new MethodVisitor(ASM4) {} + } + + new MethodVisitor(ASM4) { + override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + if (op == INVOKESPECIAL && name == "<init>" || op == INVOKESTATIC && name == "valueOf") { + if (primitiveBoxingClassName.contains(owner)) { + // Found a boxing method, e.g., new java.lang.Long(l) or java.lang.Long.valueOf(l) + boxingInvokes.add(s"$owner.$name") + } + } else { + // scalastyle:off classforname + val classOfMethodOwner = Class.forName(owner.replace('/', '.'), false, + Thread.currentThread.getContextClassLoader) + // scalastyle:on classforname + val m = MethodIdentifier(classOfMethodOwner, name, desc) + if (!visitedMethods.contains(m)) { + // Keep track of visited methods to avoid potential infinite cycles + visitedMethods += m + BoxingFinder.getClassReader(classOfMethodOwner).foreach { cl => + visitedMethods += m + cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0) + } + } + } + } + } + } +} + +private object BoxingFinder { + + def getClassReader(cls: Class[_]): Option[ClassReader] = { + val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" + val resourceStream = cls.getResourceAsStream(className) + val baos = new ByteArrayOutputStream(128) + // Copy data over, before delegating to ClassReader - + // else we can run out of open file handles. + Utils.copyStream(resourceStream, baos, true) + // ASM4 doesn't support Java 8 classes, which require ASM5. + // So if the class requires ASM5 (e.g., java.lang.Long when running on a JDK 8 runtime), + // then ClassReader will throw an IllegalArgumentException. + // However, since this is only for testing, it's safe to skip such classes. + try { + Some(new ClassReader(new ByteArrayInputStream(baos.toByteArray))) + } catch { + case _: IllegalArgumentException => None + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala index f1fcaf5953..69a561e16a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala @@ -21,6 +21,7 @@ import java.util.Properties import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.sql.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.SQLExecution @@ -65,9 +66,9 @@ class SQLListenerSuite extends SparkFunSuite { speculative = false ) - private def createTaskMetrics(accumulatorUpdates: Map[Long, Any]): TaskMetrics = { + private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = { val metrics = new TaskMetrics - metrics.setAccumulatorsUpdater(() => accumulatorUpdates) + metrics.setAccumulatorsUpdater(() => accumulatorUpdates.mapValues(new LongSQLMetricValue(_))) metrics.updateAccumulators() metrics } -- GitLab
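
Usage note (not part of the patch): a minimal sketch of how a physical operator could declare and update one of the specialized metrics introduced here, modeled on the Filter changes above. The operator name SampleOperator and its metric key are invented for illustration; everything else assumes the classes this patch adds and the surrounding Spark SQL internals at this revision.

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
    import org.apache.spark.sql.metric.SQLMetrics

    // Hypothetical operator showing the intended usage of the new metric classes.
    case class SampleOperator(child: SparkPlan) extends UnaryNode {
      override def output: Seq[Attribute] = child.output

      // Declare specialized metrics instead of generic Accumulators.
      private[sql] override lazy val metrics = Map(
        "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows"))

      protected override def doExecute(): RDD[InternalRow] = {
        // Typed accessor added to SparkPlan by this patch; no cast to Accumulator[Long] needed.
        val numRows = longMetric("numRows")
        child.execute().mapPartitions { iter =>
          iter.map { row =>
            // LongSQLMetric overrides += to call LongSQLMetricValue.add(Long), so no boxing occurs.
            numRows += 1
            row
          }
        }
      }
    }

The key design point, per the patch, is that the accumulated value type (LongSQLMetricValue / IntSQLMetricValue) is a mutable wrapper, so += mutates it in place with a primitive argument instead of autoboxing a new java.lang.Long on every row.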