diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 1893167cf7261ed1b1d4ff0ba3b0ad705d4d28c6..5bb505bf09f17d51702b34f7d86095b590412873 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -291,12 +291,20 @@ private[spark] object TaskMetrics extends Logging { private[spark] class BlockStatusesAccumulator extends AccumulatorV2[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]] { - private[this] var _seq = ArrayBuffer.empty[(BlockId, BlockStatus)] + private var _seq = ArrayBuffer.empty[(BlockId, BlockStatus)] override def isZero(): Boolean = _seq.isEmpty override def copyAndReset(): BlockStatusesAccumulator = new BlockStatusesAccumulator + override def copy(): BlockStatusesAccumulator = { + val newAcc = new BlockStatusesAccumulator + newAcc._seq = _seq.clone() + newAcc + } + + override def reset(): Unit = _seq.clear() + override def add(v: (BlockId, BlockStatus)): Unit = _seq += v override def merge(other: AccumulatorV2[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]]) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index c4879036f6522a2c5e1c97fd30b97c3f58febf1f..0cf9df084fdbe9ecca58629a4c8e2cdc02d035c0 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -112,7 +112,22 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { * Creates a new copy of this accumulator, which is zero value. i.e. call `isZero` on the copy * must return true. */ - def copyAndReset(): AccumulatorV2[IN, OUT] + def copyAndReset(): AccumulatorV2[IN, OUT] = { + val copyAcc = copy() + copyAcc.reset() + copyAcc + } + + /** + * Creates a new copy of this accumulator. + */ + def copy(): AccumulatorV2[IN, OUT] + + /** + * Resets this accumulator, which is zero value. i.e. call `isZero` must + * return true. + */ + def reset(): Unit /** * Takes the inputs and accumulates. e.g. it can be a simple `+=` for counter accumulator. @@ -137,10 +152,10 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { throw new UnsupportedOperationException( "Accumulator must be registered before send to executor") } - val copy = copyAndReset() - assert(copy.isZero, "copyAndReset must return a zero value copy") - copy.metadata = metadata - copy + val copyAcc = copyAndReset() + assert(copyAcc.isZero, "copyAndReset must return a zero value copy") + copyAcc.metadata = metadata + copyAcc } else { this } @@ -249,8 +264,8 @@ private[spark] object AccumulatorContext { * @since 2.0.0 */ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { - private[this] var _sum = 0L - private[this] var _count = 0L + private var _sum = 0L + private var _count = 0L /** * Adds v to the accumulator, i.e. increment sum by v and count by 1. @@ -258,7 +273,17 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { */ override def isZero: Boolean = _sum == 0L && _count == 0 - override def copyAndReset(): LongAccumulator = new LongAccumulator + override def copy(): LongAccumulator = { + val newAcc = new LongAccumulator + newAcc._count = this._count + newAcc._sum = this._sum + newAcc + } + + override def reset(): Unit = { + _sum = 0L + _count = 0L + } /** * Adds v to the accumulator, i.e. increment sum by v and count by 1. @@ -318,12 +343,22 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { * @since 2.0.0 */ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { - private[this] var _sum = 0.0 - private[this] var _count = 0L + private var _sum = 0.0 + private var _count = 0L override def isZero: Boolean = _sum == 0.0 && _count == 0 - override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator + override def copy(): DoubleAccumulator = { + val newAcc = new DoubleAccumulator + newAcc._count = this._count + newAcc._sum = this._sum + newAcc + } + + override def reset(): Unit = { + _sum = 0.0 + _count = 0L + } /** * Adds v to the accumulator, i.e. increment sum by v and count by 1. @@ -377,12 +412,20 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { class ListAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { - private[this] val _list: java.util.List[T] = new java.util.ArrayList[T] + private val _list: java.util.List[T] = new java.util.ArrayList[T] override def isZero: Boolean = _list.isEmpty override def copyAndReset(): ListAccumulator[T] = new ListAccumulator + override def copy(): ListAccumulator[T] = { + val newAcc = new ListAccumulator[T] + newAcc._list.addAll(_list) + newAcc + } + + override def reset(): Unit = _list.clear() + override def add(v: T): Unit = _list.add(v) override def merge(other: AccumulatorV2[T, java.util.List[T]]): Unit = other match { @@ -407,12 +450,16 @@ class LegacyAccumulatorWrapper[R, T]( override def isZero: Boolean = _value == param.zero(initialValue) - override def copyAndReset(): LegacyAccumulatorWrapper[R, T] = { + override def copy(): LegacyAccumulatorWrapper[R, T] = { val acc = new LegacyAccumulatorWrapper(initialValue, param) - acc._value = param.zero(initialValue) + acc._value = _value acc } + override def reset(): Unit = { + _value = param.zero(initialValue) + } + override def add(v: T): Unit = _value = param.addAccumulator(_value, v) override def merge(other: AccumulatorV2[T, R]): Unit = other match { diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala index ecaf4f0c643644b8c2b3f36e06ab84293d3a49ff..439da1306f5aac2d12cdac4d2f7b13f2c5e1fff5 100644 --- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala +++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala @@ -116,6 +116,15 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc.value.contains(2.0)) assert(!acc.isZero) assert(acc.value.size() === 3) + + val acc3 = acc.copy() + assert(acc3.value.contains(2.0)) + assert(!acc3.isZero) + assert(acc3.value.size() === 3) + + acc3.reset() + assert(acc3.isZero) + assert(acc3.value.isEmpty) } test("LegacyAccumulatorWrapper") { @@ -144,5 +153,13 @@ class AccumulatorV2Suite extends SparkFunSuite { acc.merge(acc2) assert(acc.value === "baz") assert(!acc.isZero) + + val acc3 = acc.copy() + assert(acc3.value === "baz") + assert(!acc3.isZero) + + acc3.reset() + assert(acc3.isZero) + assert(acc3.value === "") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 786110477d8cf348d4ab404b58d080a3bf656678..d6de15494fefaf4ea769d5e69faa363b97276753 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -30,8 +30,15 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato // update it at the end of task and the value will be at least 0. Then we can filter out the -1 // values before calculate max, min, etc. private[this] var _value = initValue + private var _zeroValue = initValue - override def copyAndReset(): SQLMetric = new SQLMetric(metricType, initValue) + override def copy(): SQLMetric = { + val newAcc = new SQLMetric(metricType, _value) + newAcc._zeroValue = initValue + newAcc + } + + override def reset(): Unit = _value = _zeroValue override def merge(other: AccumulatorV2[Long, Long]): Unit = other match { case o: SQLMetric => _value += o.value @@ -39,7 +46,7 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - override def isZero(): Boolean = _value == initValue + override def isZero(): Boolean = _value == _zeroValue override def add(v: Long): Unit = _value += v @@ -51,8 +58,6 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato private[spark] override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { new AccumulableInfo(id, name, update, value, true, true, Some(SQLMetrics.ACCUM_IDENTIFIER)) } - - def reset(): Unit = _value = initValue }