diff --git a/src/examples/Vector.scala b/src/examples/Vector.scala index 0ae2cbc6e8d16cff6df992e965f8deb63cb4a86d..e9fbdca7523a36e8067b1d34e69960b35039d410 100644 --- a/src/examples/Vector.scala +++ b/src/examples/Vector.scala @@ -57,7 +57,7 @@ object Vector { implicit def doubleToMultiplier(num: Double) = new Multiplier(num) implicit object VectorAccumParam extends spark.AccumulatorParam[Vector] { - def add(t1: Vector, t2: Vector) = t1 + t2 + def addInPlace(t1: Vector, t2: Vector) = t1 + t2 def zero(initialValue: Vector) = Vector.zeros(initialValue.length) } } diff --git a/src/scala/spark/Accumulators.scala b/src/scala/spark/Accumulators.scala index 3e4cd4935a19cce8599b0dac05046b5fc7966bae..ee93d3c85c0e9e318d8f0a1449a9fe81f4d6da07 100644 --- a/src/scala/spark/Accumulators.scala +++ b/src/scala/spark/Accumulators.scala @@ -4,15 +4,17 @@ import java.io._ import scala.collection.mutable.Map -@serializable class Accumulator[T](initialValue: T, param: AccumulatorParam[T]) +@serializable class Accumulator[T]( + @transient initialValue: T, param: AccumulatorParam[T]) { val id = Accumulators.newId - @transient var value_ = initialValue + @transient var value_ = initialValue // Current value on master + val zero = param.zero(initialValue) // Zero value to be passed to workers var deserialized = false Accumulators.register(this) - def += (term: T) { value_ = param.add(value_, term) } + def += (term: T) { value_ = param.addInPlace(value_, term) } def value = this.value_ def value_= (t: T) { if (!deserialized) value_ = t @@ -22,7 +24,7 @@ import scala.collection.mutable.Map // Called by Java when deserializing an object private def readObject(in: ObjectInputStream) { in.defaultReadObject - value_ = param.zero(initialValue) + value_ = zero deserialized = true Accumulators.register(this) } @@ -31,7 +33,7 @@ import scala.collection.mutable.Map } @serializable trait AccumulatorParam[T] { - def add(t1: T, t2: T): T + def addInPlace(t1: T, t2: T): T def zero(initialValue: T): T } diff --git a/src/scala/spark/MesosScheduler.scala b/src/scala/spark/MesosScheduler.scala index 984a5e56377effdd4b9e38bcf6a6475e00894555..2f1c7431c57a1ac20a40a00ce3fa44788859af79 100644 --- a/src/scala/spark/MesosScheduler.scala +++ b/src/scala/spark/MesosScheduler.scala @@ -304,6 +304,7 @@ extends ParallelOperation val result = Utils.deserialize[TaskResult[T]](status.getData) results(tidToIndex(tid)) = result.value // Update accumulators + print(" with " + result.accumUpdates.size + " accumulatedUpdates") Accumulators.add(callingThread, result.accumUpdates) // Mark finished and stop if we've finished all the tasks finished(tidToIndex(tid)) = true diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala index d5d4db4678eb4df00df59350286eea879d52e276..c26032bb4f3d7f13a6ac518c50a5c74e807ab506 100644 --- a/src/scala/spark/SparkContext.scala +++ b/src/scala/spark/SparkContext.scala @@ -82,11 +82,11 @@ class SparkContext(master: String, frameworkName: String) { object SparkContext { implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { - def add(t1: Double, t2: Double): Double = t1 + t2 + def addInPlace(t1: Double, t2: Double): Double = t1 + t2 def zero(initialValue: Double) = 0.0 } implicit object IntAccumulatorParam extends AccumulatorParam[Int] { - def add(t1: Int, t2: Int): Int = t1 + t2 + def addInPlace(t1: Int, t2: Int): Int = t1 + t2 def zero(initialValue: Int) = 0 } // TODO: Add AccumulatorParams for other types, e.g. lists and strings