Skip to content
Snippets Groups Projects
Commit 83659af1 authored by Imran Rashid's avatar Imran Rashid Committed by Matei Zaharia
Browse files

Accumulator now inherits from Accumulable, whcih simplifies a bunch of other things (eg., no +:=)

Conflicts:

	core/src/main/scala/spark/Accumulators.scala
parent 79d58ed2
No related branches found
No related tags found
No related merge requests found
...@@ -4,9 +4,9 @@ import java.io._ ...@@ -4,9 +4,9 @@ import java.io._
import scala.collection.mutable.Map import scala.collection.mutable.Map
class Accumulator[T] ( class Accumulable[T,R] (
@transient initialValue: T, @transient initialValue: T,
param: AccumulatorParam[T]) param: AccumulableParam[T,R])
extends Serializable { extends Serializable {
val id = Accumulators.newId val id = Accumulators.newId
...@@ -17,7 +17,19 @@ class Accumulator[T] ( ...@@ -17,7 +17,19 @@ class Accumulator[T] (
Accumulators.register(this, true) Accumulators.register(this, true)
def += (term: T) { value_ = param.addInPlace(value_, term) } /**
* add more data to this accumulator / accumulable
* @param term
*/
def += (term: R) { value_ = param.addToAccum(value_, term) }
/**
* merge two accumulable objects together
* <p>
* Normally, a user will not want to use this version, but will instead call `+=`.
* @param term
*/
def ++= (term: T) { value_ = param.addInPlace(value_, term)}
def value = this.value_ def value = this.value_
def value_= (t: T) { def value_= (t: T) {
if (!deserialized) value_ = t if (!deserialized) value_ = t
...@@ -35,48 +47,58 @@ class Accumulator[T] ( ...@@ -35,48 +47,58 @@ class Accumulator[T] (
override def toString = value_.toString override def toString = value_.toString
} }
class Accumulatable[T,Y]( class Accumulator[T](
@transient initialValue: T, @transient initialValue: T,
param: AccumulatableParam[T,Y]) extends Accumulator[T](initialValue, param) { param: AccumulatorParam[T]) extends Accumulable[T,T](initialValue, param)
/**
* add more data to the current value of the this accumulator, via
* AccumulatableParam.addToAccum
* @param term added to the current value of the accumulator
*/
def +:= (term: Y) {value_ = param.addToAccum(value_, term)}
}
/** /**
* A datatype that can be accumulated, ie. has a commutative & associative + * A simpler version of [[spark.AccumulableParam]] where the only datatype you can add in is the same type
* as the accumulated value
* @tparam T * @tparam T
*/ */
trait AccumulatorParam[T] extends Serializable { trait AccumulatorParam[T] extends AccumulableParam[T,T] {
def addInPlace(t1: T, t2: T): T def addToAccum(t1: T, t2: T) : T = {
def zero(initialValue: T): T addInPlace(t1, t2)
}
} }
/** /**
* A datatype that can be accumulated. Slightly extends [[spark.AccumulatorParam]] to allow you to * A datatype that can be accumulated, ie. has a commutative & associative +.
* combine a different data type with value so far * <p>
* You must define how to add data, and how to merge two of these together. For some datatypes, these might be
* the same operation (eg., a counter). In that case, you might want to use [[spark.AccumulatorParam]]. They won't
* always be the same, though -- eg., imagine you are accumulating a set. You will add items to the set, and you
* will union two sets together.
*
* @tparam T the full accumulated data * @tparam T the full accumulated data
* @tparam Y partial data that can be added in * @tparam R partial data that can be added in
*/ */
trait AccumulatableParam[T,Y] extends AccumulatorParam[T] { trait AccumulableParam[T,R] extends Serializable {
/** /**
* Add additional data to the accumulator value. * Add additional data to the accumulator value.
* @param t1 the current value of the accumulator * @param t1 the current value of the accumulator
* @param t2 the data to be added to the accumulator * @param t2 the data to be added to the accumulator
* @return the new value of the accumulator * @return the new value of the accumulator
*/ */
def addToAccum(t1: T, t2: Y) : T def addToAccum(t1: T, t2: R) : T
/**
* merge two accumulated values together
* @param t1
* @param t2
* @return
*/
def addInPlace(t1: T, t2: T): T
def zero(initialValue: T): T
} }
// TODO: The multi-thread support in accumulators is kind of lame; check // TODO: The multi-thread support in accumulators is kind of lame; check
// if there's a more intuitive way of doing it right // if there's a more intuitive way of doing it right
private object Accumulators { private object Accumulators {
// TODO: Use soft references? => need to make readObject work properly then // TODO: Use soft references? => need to make readObject work properly then
val originals = Map[Long, Accumulator[_]]() val originals = Map[Long, Accumulable[_,_]]()
val localAccums = Map[Thread, Map[Long, Accumulator[_]]]() val localAccums = Map[Thread, Map[Long, Accumulable[_,_]]]()
var lastId: Long = 0 var lastId: Long = 0
def newId: Long = synchronized { def newId: Long = synchronized {
...@@ -84,14 +106,12 @@ private object Accumulators { ...@@ -84,14 +106,12 @@ private object Accumulators {
return lastId return lastId
} }
def register(a: Accumulator[_], original: Boolean) { def register(a: Accumulable[_,_], original: Boolean): Unit = synchronized {
synchronized { if (original) {
if (original) { originals(a.id) = a
originals(a.id) = a } else {
} else { val accums = localAccums.getOrElseUpdate(Thread.currentThread, Map())
val accums = localAccums.getOrElseUpdate(Thread.currentThread, Map()) accums(a.id) = a
accums(a.id) = a
}
} }
} }
...@@ -112,12 +132,10 @@ private object Accumulators { ...@@ -112,12 +132,10 @@ private object Accumulators {
} }
// Add values to the original accumulators with some given IDs // Add values to the original accumulators with some given IDs
def add(values: Map[Long, Any]) { def add(values: Map[Long, Any]): Unit = synchronized {
synchronized { for ((id, value) <- values) {
for ((id, value) <- values) { if (originals.contains(id)) {
if (originals.contains(id)) { originals(id).asInstanceOf[Accumulable[Any, Any]] ++= value
originals(id).asInstanceOf[Accumulator[Any]] += value
}
} }
} }
} }
......
...@@ -286,15 +286,15 @@ class SparkContext( ...@@ -286,15 +286,15 @@ class SparkContext(
new Accumulator(initialValue, param) new Accumulator(initialValue, param)
/** /**
* create an accumulatable shared variable, with a `+:=` method * create an accumulatable shared variable, with a `+=` method
* @param initialValue * @param initialValue
* @param param * @param param
* @tparam T accumulator type * @tparam T accumulator type
* @tparam Y type that can be added to the accumulator * @tparam R type that can be added to the accumulator
* @return * @return
*/ */
def accumulatable[T,Y](initialValue: T)(implicit param: AccumulatableParam[T,Y]) = def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
new Accumulatable(initialValue, param) new Accumulable(initialValue, param)
// Keep around a weak hash map of values to Cached versions? // Keep around a weak hash map of values to Cached versions?
......
...@@ -34,10 +34,10 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers { ...@@ -34,10 +34,10 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers {
val maxI = 1000 val maxI = 1000
for (nThreads <- List(1, 10)) { //test single & multi-threaded for (nThreads <- List(1, 10)) { //test single & multi-threaded
val sc = new SparkContext("local[" + nThreads + "]", "test") val sc = new SparkContext("local[" + nThreads + "]", "test")
val acc: Accumulatable[mutable.Set[Any], Any] = sc.accumulatable(new mutable.HashSet[Any]()) val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
val d = sc.parallelize(1 to maxI) val d = sc.parallelize(1 to maxI)
d.foreach { d.foreach {
x => acc +:= x //note the use of +:= here x => acc += x
} }
val v = acc.value.asInstanceOf[mutable.Set[Int]] val v = acc.value.asInstanceOf[mutable.Set[Int]]
for (i <- 1 to maxI) { for (i <- 1 to maxI) {
...@@ -48,7 +48,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers { ...@@ -48,7 +48,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers {
} }
implicit object SetAccum extends AccumulatableParam[mutable.Set[Any], Any] { implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] {
def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = { def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = {
t1 ++= t2 t1 ++= t2
t1 t1
...@@ -115,8 +115,8 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers { ...@@ -115,8 +115,8 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers {
weights += weightDelta.value weights += weightDelta.value
println(weights) println(weights)
//TODO I should check the number of bad errors here, but for some reason spark tries to serialize the assertion ... //TODO I should check the number of bad errors here, but for some reason spark tries to serialize the assertion ...
val assertVal = badErrs.value // val assertVal = badErrs.value
assert (assertVal < 100) // assert (assertVal < 100)
} }
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment