diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala b/core/src/main/scala/org/apache/spark/Accumulator.scala
index e52d36b7b564b003a4e25667089e86c22fffc495..23245043e24656665e2fbb25932d9ab4c9d40457 100644
--- a/core/src/main/scala/org/apache/spark/Accumulator.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulator.scala
@@ -17,9 +17,6 @@
 
 package org.apache.spark
 
-import org.apache.spark.storage.{BlockId, BlockStatus}
-
-
 /**
  * A simpler value of [[Accumulable]] where the result type being accumulated is the same
  * as the types of elements being merged, i.e. variables that are only "added" to through an
@@ -117,18 +114,4 @@ object AccumulatorParam {
     def addInPlace(t1: String, t2: String): String = t2
     def zero(initialValue: String): String = ""
   }
-
-  // Note: this is expensive as it makes a copy of the list every time the caller adds an item.
-  // A better way to use this is to first accumulate the values yourself then them all at once.
-  @deprecated("use AccumulatorV2", "2.0.0")
-  private[spark] class ListAccumulatorParam[T] extends AccumulatorParam[Seq[T]] {
-    def addInPlace(t1: Seq[T], t2: Seq[T]): Seq[T] = t1 ++ t2
-    def zero(initialValue: Seq[T]): Seq[T] = Seq.empty[T]
-  }
-
-  // For the internal metric that records what blocks are updated in a particular task
-  @deprecated("use AccumulatorV2", "2.0.0")
-  private[spark] object UpdatedBlockStatusesAccumulatorParam
-    extends ListAccumulatorParam[(BlockId, BlockStatus)]
-
 }
diff --git a/core/src/main/scala/org/apache/spark/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
index c65108a55eba498ea70df3f34eeea9d661a28e76..a6c64fd680573437fed9ead5a8e597bdf34c0320 100644
--- a/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
@@ -257,23 +257,66 @@ private[spark] object AccumulatorContext {
 }
 
 
+/**
+ * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for 64-bit integers.
+ *
+ * @since 2.0.0
+ */
 class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
   private[this] var _sum = 0L
+  private[this] var _count = 0L
 
-  override def isZero: Boolean = _sum == 0
+  /**
+   * Returns true if this accumulator has not had any values added to it, i.e. count is zero.
+   * @since 2.0.0
+   */
+  override def isZero: Boolean = _count == 0L
 
   override def copyAndReset(): LongAccumulator = new LongAccumulator
 
-  override def add(v: jl.Long): Unit = _sum += v
+  /**
+   * Adds v to the accumulator, i.e. increments sum by v and count by 1.
+   * @since 2.0.0
+   */
+  override def add(v: jl.Long): Unit = {
+    _sum += v
+    _count += 1
+  }
+
+  /**
+   * Adds v to the accumulator, i.e. increments sum by v and count by 1.
+   * @since 2.0.0
+   */
+  def add(v: Long): Unit = {
+    _sum += v
+    _count += 1
+  }
 
-  def add(v: Long): Unit = _sum += v
+  /**
+   * Returns the number of elements added to the accumulator.
+   * @since 2.0.0
+   */
+  def count: Long = _count
 
+  /**
+   * Returns the sum of elements added to the accumulator.
+   * @since 2.0.0
+   */
   def sum: Long = _sum
 
+  /**
+   * Returns the average of elements added to the accumulator.
+   * @since 2.0.0
+   */
+  def avg: Double = _sum.toDouble / _count
+
   override def merge(other: AccumulatorV2[jl.Long, jl.Long]): Unit = other match {
-    case o: LongAccumulator => _sum += o.sum
-    case _ => throw new UnsupportedOperationException(
-      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+    case o: LongAccumulator =>
+      _sum += o.sum
+      _count += o.count
+    case _ =>
+      throw new UnsupportedOperationException(
+        s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
   }
 
   private[spark] def setValue(newValue: Long): Unit = _sum = newValue
@@ -282,66 +325,68 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
 }
 
 
+/**
+ * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for double precision
+ * floating point numbers.
+ *
+ * @since 2.0.0
+ */
 class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
   private[this] var _sum = 0.0
-
-  override def isZero: Boolean = _sum == 0.0
-
-  override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator
-
-  override def add(v: jl.Double): Unit = _sum += v
-
-  def add(v: Double): Unit = _sum += v
-
-  def sum: Double = _sum
-
-  override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match {
-    case o: DoubleAccumulator => _sum += o.sum
-    case _ => throw new UnsupportedOperationException(
-      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
-  }
-
-  private[spark] def setValue(newValue: Double): Unit = _sum = newValue
-
-  override def localValue: jl.Double = _sum
-}
-
-
-class AverageAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
-  private[this] var _sum = 0.0
   private[this] var _count = 0L
 
-  override def isZero: Boolean = _sum == 0.0 && _count == 0
+  override def isZero: Boolean = _count == 0L
 
-  override def copyAndReset(): AverageAccumulator = new AverageAccumulator
+  override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator
 
+  /**
+   * Adds v to the accumulator, i.e. increments sum by v and count by 1.
+   * @since 2.0.0
+   */
   override def add(v: jl.Double): Unit = {
     _sum += v
     _count += 1
   }
 
-  def add(d: Double): Unit = {
-    _sum += d
+  /**
+   * Adds v to the accumulator, i.e. increments sum by v and count by 1.
+   * @since 2.0.0
+   */
+  def add(v: Double): Unit = {
+    _sum += v
     _count += 1
   }
 
+  /**
+   * Returns the number of elements added to the accumulator.
+   * @since 2.0.0
+   */
+  def count: Long = _count
+
+  /**
+   * Returns the sum of elements added to the accumulator.
+   * @since 2.0.0
+   */
+  def sum: Double = _sum
+
+  /**
+   * Returns the average of elements added to the accumulator.
+   * @since 2.0.0
+   */
+  def avg: Double = _sum / _count
+
   override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match {
-    case o: AverageAccumulator =>
+    case o: DoubleAccumulator =>
       _sum += o.sum
       _count += o.count
-    case _ => throw new UnsupportedOperationException(
-      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
-  }
-
-  override def localValue: jl.Double = if (_count == 0) {
-    Double.NaN
-  } else {
-    _sum / _count
+    case _ =>
+      throw new UnsupportedOperationException(
+        s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
   }
 
-  def sum: Double = _sum
+  private[spark] def setValue(newValue: Double): Unit = _sum = newValue
 
-  def count: Long = _count
+  override def localValue: jl.Double = _sum
 }
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 58618b41920af9b0595bc7d310c6a0b90b35a807..e391599336074831fbca28c67f58f7b4e62d88d9 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1340,28 +1340,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     acc
   }
 
-  /**
-   * Create and register an average accumulator, which accumulates double inputs by recording the
-   * total sum and total count, and produce the output by sum / total. Note that Double.NaN will be
-   * returned if no input is added.
-   */
-  def averageAccumulator: AverageAccumulator = {
-    val acc = new AverageAccumulator
-    register(acc)
-    acc
-  }
-
-  /**
-   * Create and register an average accumulator, which accumulates double inputs by recording the
-   * total sum and total count, and produce the output by sum / total. Note that Double.NaN will be
-   * returned if no input is added.
-   */
-  def averageAccumulator(name: String): AverageAccumulator = {
-    val acc = new AverageAccumulator
-    register(acc, name)
-    acc
-  }
-
   /**
    * Create and register a list accumulator, which starts with empty list and accumulates inputs
    * by adding them into the inner list.
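Usage sketch (not part of the patch): with averageAccumulator removed, the same running average comes from registering a DoubleAccumulator and reading avg on the driver. The sketch assumes a live SparkContext named `sc` (spark-shell provides one) and uses only the register(acc, name) helper and the DoubleAccumulator API shown above.

    import org.apache.spark.DoubleAccumulator

    val acc = new DoubleAccumulator
    sc.register(acc, "response-time-avg")  // same register(acc, name) the removed helpers called

    // Tasks on the executors add values; the reads below happen on the driver after the job.
    sc.parallelize(Seq(1.0, 2.0, 3.0)).foreach(v => acc.add(v))

    println(s"sum=${acc.sum} count=${acc.count} avg=${acc.avg}")  // sum=6.0 count=3 avg=2.0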
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 09eb9c1dbdc62e631260c227820f0a053e2c0353..00200962549e4787e5a2a59fa9c416253a83b295 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -28,7 +28,7 @@ import scala.util.control.NonFatal
 import org.scalatest.Matchers
 import org.scalatest.exceptions.TestFailedException
 
-import org.apache.spark.AccumulatorParam.{ListAccumulatorParam, StringAccumulatorParam}
+import org.apache.spark.AccumulatorParam.StringAccumulatorParam
 import org.apache.spark.scheduler._
 import org.apache.spark.serializer.JavaSerializer
 
@@ -234,21 +234,6 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
     acc.merge("kindness")
     assert(acc.value === "kindness")
   }
-
-  test("list accumulator param") {
-    val acc = new Accumulator(Seq.empty[Int], new ListAccumulatorParam[Int], Some("numbers"))
-    assert(acc.value === Seq.empty[Int])
-    acc.add(Seq(1, 2))
-    assert(acc.value === Seq(1, 2))
-    acc += Seq(3, 4)
-    assert(acc.value === Seq(1, 2, 3, 4))
-    acc ++= Seq(5, 6)
-    assert(acc.value === Seq(1, 2, 3, 4, 5, 6))
-    acc.merge(Seq(7, 8))
-    assert(acc.value === Seq(1, 2, 3, 4, 5, 6, 7, 8))
-    acc.setValue(Seq(9, 10))
-    assert(acc.value === Seq(9, 10))
-  }
 }
 
 private[spark] object AccumulatorSuite {
diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..41cdd024922611173c6d8130793253b115834a53
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import org.apache.spark.{DoubleAccumulator, LongAccumulator, SparkFunSuite}
+
+class AccumulatorV2Suite extends SparkFunSuite {
+
+  test("LongAccumulator add/avg/sum/count/isZero") {
+    val acc = new LongAccumulator
+    assert(acc.isZero)
+    assert(acc.count == 0)
+    assert(acc.sum == 0)
+    assert(acc.avg.isNaN)
+
+    acc.add(0)
+    assert(!acc.isZero)
+    assert(acc.count == 1)
+    assert(acc.sum == 0)
+    assert(acc.avg == 0.0)
+
+    acc.add(1)
+    assert(acc.count == 2)
+    assert(acc.sum == 1)
+    assert(acc.avg == 0.5)
+
+    // Also test add using non-specialized add function
+    acc.add(new java.lang.Long(2))
+    assert(acc.count == 3)
+    assert(acc.sum == 3)
+    assert(acc.avg == 1.0)
+
+    // Test merging
+    val acc2 = new LongAccumulator
+    acc2.add(2)
+    acc.merge(acc2)
+    assert(acc.count == 4)
+    assert(acc.sum == 5)
+    assert(acc.avg == 1.25)
+  }
+
+  test("DoubleAccumulator add/avg/sum/count/isZero") {
+    val acc = new DoubleAccumulator
+    assert(acc.isZero)
+    assert(acc.count == 0)
+    assert(acc.sum == 0.0)
+    assert(acc.avg.isNaN)
+
+    acc.add(0.0)
+    assert(!acc.isZero)
+    assert(acc.count == 1)
+    assert(acc.sum == 0.0)
+    assert(acc.avg == 0.0)
+
+    acc.add(1.0)
+    assert(acc.count == 2)
+    assert(acc.sum == 1.0)
+    assert(acc.avg == 0.5)
+
+    // Also test add using non-specialized add function
+    acc.add(new java.lang.Double(2.0))
+    assert(acc.count == 3)
+    assert(acc.sum == 3.0)
+    assert(acc.avg == 1.0)
+
+    // Test merging
+    val acc2 = new DoubleAccumulator
+    acc2.add(2.0)
+    acc.merge(acc2)
+    assert(acc.count == 4)
+    assert(acc.sum == 5.0)
+    assert(acc.avg == 1.25)
+  }
+}
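Merge semantics sketch (not part of the patch): unregistered accumulators are plain objects, so the contract exercised by the suite above can be sanity-checked without a SparkContext. This assumes nothing beyond the LongAccumulator API added in this diff.

    import org.apache.spark.LongAccumulator

    val a = new LongAccumulator
    val b = new LongAccumulator
    Seq(1L, 2L, 3L).foreach(v => a.add(v))  // a: sum = 6, count = 3
    b.add(10L)                              // b: sum = 10, count = 1
    a.merge(b)                              // merge folds both sum and count into a
    assert(a.sum == 16 && a.count == 4 && a.avg == 4.0)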