Skip to content
Snippets Groups Projects
Commit b7ce592b authored by Justin Ma's avatar Justin Ma
Browse files

changes to accumulator to add objects in-place.

parent 366c09c4
No related branches found
No related tags found
No related merge requests found
......@@ -57,7 +57,7 @@ object Vector {
// Implicit conversion that wraps a Double in a Multiplier.
// (Multiplier is defined outside this excerpt.)
implicit def doubleToMultiplier(num: Double) = new Multiplier(num)
// Implicit AccumulatorParam instance for Vector.
// NOTE(review): `add` and `addInPlace` look like the before/after versions of a
// single method taken from a diff whose +/- markers were lost — confirm against
// the repository which one the real file contains.
implicit object VectorAccumParam extends spark.AccumulatorParam[Vector] {
// Combines two vectors using Vector's `+` operator.
def add(t1: Vector, t2: Vector) = t1 + t2
// Same expression as `add`.
// NOTE(review): whether `+` mutates t1 or builds a new Vector is not visible
// here — confirm that this is really an in-place addition.
def addInPlace(t1: Vector, t2: Vector) = t1 + t2
// Zero element: Vector.zeros with the same length as initialValue.
def zero(initialValue: Vector) = Vector.zeros(initialValue.length)
}
}
......@@ -4,15 +4,17 @@ import java.io._
import scala.collection.mutable.Map
@serializable class Accumulator[T](initialValue: T, param: AccumulatorParam[T])
@serializable class Accumulator[T](
@transient initialValue: T, param: AccumulatorParam[T])
{
val id = Accumulators.newId
@transient var value_ = initialValue
@transient var value_ = initialValue // Current value on master
val zero = param.zero(initialValue) // Zero value to be passed to workers
var deserialized = false
Accumulators.register(this)
// Folds `term` into the stored value_ by delegating to the AccumulatorParam and
// storing the result back into value_.
// NOTE(review): the two definitions below look like the old (param.add) and new
// (param.addInPlace) versions of the same method from a diff without +/- markers;
// only one of them can exist in the real source — confirm against the repository.
def += (term: T) { value_ = param.add(value_, term) }
def += (term: T) { value_ = param.addInPlace(value_, term) }
// Read accessor: returns the current contents of the value_ field.
def value = this.value_
def value_= (t: T) {
if (!deserialized) value_ = t
......@@ -22,7 +24,7 @@ import scala.collection.mutable.Map
// Called by Java when deserializing an object
// After the default field restore, this resets value_ (declared @transient),
// flags the instance as deserialized and registers it with Accumulators.
private def readObject(in: ObjectInputStream) {
in.defaultReadObject
// NOTE(review): the next two assignments look like the old and new versions of
// the same line from a diff without +/- markers — confirm which one is current.
value_ = param.zero(initialValue)
// `zero` is the field initialised as param.zero(initialValue) in the class body.
value_ = zero
// The value_= setter only assigns while this flag is false.
deserialized = true
Accumulators.register(this)
}
......@@ -31,7 +33,7 @@ import scala.collection.mutable.Map
}
// Describes how values of type T are accumulated: how two values are combined
// and what the zero value for a given initial value is.
// NOTE(review): `add` and `addInPlace` look like the old and new names of one
// method from a diff without +/- markers — confirm which one the real trait has.
@serializable trait AccumulatorParam[T] {
// Combines t1 and t2 into a single value of type T.
def add(t1: T, t2: T): T
// Combines t1 and t2 into a single value of type T; Accumulator's `+=` stores
// the returned value.
def addInPlace(t1: T, t2: T): T
// Returns the zero value of T derived from initialValue.
def zero(initialValue: T): T
}
......
......@@ -304,6 +304,7 @@ extends ParallelOperation
val result = Utils.deserialize[TaskResult[T]](status.getData)
results(tidToIndex(tid)) = result.value
// Update accumulators
print(" with " + result.accumUpdates.size + " accumulatedUpdates")
Accumulators.add(callingThread, result.accumUpdates)
// Mark finished and stop if we've finished all the tasks
finished(tidToIndex(tid)) = true
......
......@@ -82,11 +82,11 @@ class SparkContext(master: String, frameworkName: String) {
object SparkContext {
// Implicit AccumulatorParam instance for Double: values are summed, zero is 0.0.
// NOTE(review): `add` and `addInPlace` look like the before/after versions of one
// method from a diff without +/- markers — confirm against the repository.
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
// Sum of the two doubles.
def add(t1: Double, t2: Double): Double = t1 + t2
// Sum of the two doubles (same expression as `add`).
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
// Always 0.0; initialValue is ignored.
def zero(initialValue: Double) = 0.0
}
// Implicit AccumulatorParam instance for Int: values are summed, zero is 0.
// NOTE(review): `add` and `addInPlace` look like the before/after versions of one
// method from a diff without +/- markers — confirm against the repository.
implicit object IntAccumulatorParam extends AccumulatorParam[Int] {
// Sum of the two ints.
def add(t1: Int, t2: Int): Int = t1 + t2
// Sum of the two ints (same expression as `add`).
def addInPlace(t1: Int, t2: Int): Int = t1 + t2
// Always 0; initialValue is ignored.
def zero(initialValue: Int) = 0
}
// TODO: Add AccumulatorParams for other types, e.g. lists and strings
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment