From b7ce592bece278571ca612fd025cc34f41c3448a Mon Sep 17 00:00:00 2001 From: Justin Ma <jtma@eecs.berkeley.edu> Date: Sat, 25 Sep 2010 14:37:25 -0700 Subject: [PATCH] changes to accumulator to add objects in-place. --- src/examples/Vector.scala | 2 +- src/scala/spark/Accumulators.scala | 12 +++++++----- src/scala/spark/MesosScheduler.scala | 1 + src/scala/spark/SparkContext.scala | 4 ++-- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/examples/Vector.scala b/src/examples/Vector.scala index 0ae2cbc6e8..e9fbdca752 100644 --- a/src/examples/Vector.scala +++ b/src/examples/Vector.scala @@ -57,7 +57,7 @@ object Vector { implicit def doubleToMultiplier(num: Double) = new Multiplier(num) implicit object VectorAccumParam extends spark.AccumulatorParam[Vector] { - def add(t1: Vector, t2: Vector) = t1 + t2 + def addInPlace(t1: Vector, t2: Vector) = t1 + t2 def zero(initialValue: Vector) = Vector.zeros(initialValue.length) } } diff --git a/src/scala/spark/Accumulators.scala b/src/scala/spark/Accumulators.scala index 3e4cd4935a..ee93d3c85c 100644 --- a/src/scala/spark/Accumulators.scala +++ b/src/scala/spark/Accumulators.scala @@ -4,15 +4,17 @@ import java.io._ import scala.collection.mutable.Map -@serializable class Accumulator[T](initialValue: T, param: AccumulatorParam[T]) +@serializable class Accumulator[T]( + @transient initialValue: T, param: AccumulatorParam[T]) { val id = Accumulators.newId - @transient var value_ = initialValue + @transient var value_ = initialValue // Current value on master + val zero = param.zero(initialValue) // Zero value to be passed to workers var deserialized = false Accumulators.register(this) - def += (term: T) { value_ = param.add(value_, term) } + def += (term: T) { value_ = param.addInPlace(value_, term) } def value = this.value_ def value_= (t: T) { if (!deserialized) value_ = t @@ -22,7 +24,7 @@ import scala.collection.mutable.Map // Called by Java when deserializing an object private def readObject(in: ObjectInputStream) { in.defaultReadObject - value_ = param.zero(initialValue) + value_ = zero deserialized = true Accumulators.register(this) } @@ -31,7 +33,7 @@ import scala.collection.mutable.Map } @serializable trait AccumulatorParam[T] { - def add(t1: T, t2: T): T + def addInPlace(t1: T, t2: T): T def zero(initialValue: T): T } diff --git a/src/scala/spark/MesosScheduler.scala b/src/scala/spark/MesosScheduler.scala index 984a5e5637..2f1c7431c5 100644 --- a/src/scala/spark/MesosScheduler.scala +++ b/src/scala/spark/MesosScheduler.scala @@ -304,6 +304,7 @@ extends ParallelOperation val result = Utils.deserialize[TaskResult[T]](status.getData) results(tidToIndex(tid)) = result.value // Update accumulators + print(" with " + result.accumUpdates.size + " accumulatedUpdates") Accumulators.add(callingThread, result.accumUpdates) // Mark finished and stop if we've finished all the tasks finished(tidToIndex(tid)) = true diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala index d5d4db4678..c26032bb4f 100644 --- a/src/scala/spark/SparkContext.scala +++ b/src/scala/spark/SparkContext.scala @@ -82,11 +82,11 @@ class SparkContext(master: String, frameworkName: String) { object SparkContext { implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { - def add(t1: Double, t2: Double): Double = t1 + t2 + def addInPlace(t1: Double, t2: Double): Double = t1 + t2 def zero(initialValue: Double) = 0.0 } implicit object IntAccumulatorParam extends AccumulatorParam[Int] { - def add(t1: Int, t2: Int): Int = t1 + t2 + def addInPlace(t1: Int, t2: Int): Int = t1 + t2 def zero(initialValue: Int) = 0 } // TODO: Add AccumulatorParams for other types, e.g. lists and strings -- GitLab