diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index e63651fcb0d844f6568c1ca05158eca10b187e04..a155adaa8706d7d31bc0acbb6e06b68fc748e9c7 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -11,7 +11,7 @@ class Accumulable[T,R] (
 
   val id = Accumulators.newId
   @transient
-  var value_ = initialValue // Current value on master
+  private var value_ = initialValue // Current value on master
   val zero = param.zero(initialValue) // Zero value to be passed to workers
   var deserialized = false
 
@@ -30,7 +30,13 @@ class Accumulable[T,R] (
    * @param term the other Accumulable that will get merged with this
    */
   def ++= (term: T) { value_ = param.addInPlace(value_, term)}
-  def value = this.value_
+  def value = {
+    if (!deserialized) value_
+    else throw new UnsupportedOperationException("Can't use read value in task")
+  }
+
+  private[spark] def localValue = value_
+
   def value_= (t: T) {
     if (!deserialized) value_ = t
     else throw new UnsupportedOperationException("Can't use value_= in task")
@@ -126,7 +132,7 @@ private object Accumulators {
   def values: Map[Long, Any] = synchronized {
     val ret = Map[Long, Any]()
     for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) {
-      ret(id) = accum.value
+      ret(id) = accum.localValue
     }
     return ret
   }
diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala
index d9ef8797d6473304628da9c8a9917a5d41f3f3b8..a59b77fc857d12197b469f7d6cefc6601dbf5842 100644
--- a/core/src/test/scala/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/spark/AccumulatorSuite.scala
@@ -63,60 +63,19 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers {
   }
 
-  test ("value readable in tasks") {
-    import spark.util.Vector
-    //stochastic gradient descent with weights stored in accumulator -- should be able to read value as we go
-
-    //really easy data
-    val N = 10000 // Number of data points
-    val D = 10 // Numer of dimensions
-    val R = 0.7 // Scaling factor
-    val ITERATIONS = 5
-    val rand = new Random(42)
-
-    case class DataPoint(x: Vector, y: Double)
-
-    def generateData = {
-      def generatePoint(i: Int) = {
-        val y = if(i % 2 == 0) -1 else 1
-        val goodX = Vector(D, _ => 0.0001 * rand.nextGaussian() + y)
-        val noiseX = Vector(D, _ => rand.nextGaussian())
-        val x = Vector((goodX.elements.toSeq ++ noiseX.elements.toSeq): _*)
-        DataPoint(x, y)
-      }
-      Array.tabulate(N)(generatePoint)
-    }
-
-    val data = generateData
-    for (nThreads <- List(1, 10)) {
-      //test single & multi-threaded
-      val sc = new SparkContext("local[" + nThreads + "]", "test")
-      val weights = Vector.zeros(2*D)
-      val weightDelta = sc.accumulator(Vector.zeros(2 * D))
-      for (itr <- 1 to ITERATIONS) {
-        val eta = 0.1 / itr
-        val badErrs = sc.accumulator(0)
-        sc.parallelize(data).foreach {
-          p => {
-            //XXX Note the call to .value here. That is required for this to be an online gradient descent
-            // instead of a batch version. Should it change to .localValue, and should .value throw an error
-            // if you try to do this??
-            val prod = weightDelta.value.plusDot(weights, p.x)
-            val trueClassProb = (1 / (1 + exp(-p.y * prod))) // works b/c p(-z) = 1 - p(z) (where p is the logistic function)
-            val update = p.x * trueClassProb * p.y * eta
-            //we could also include a momentum term here if our weightDelta accumulator saved a momentum
-            weightDelta.value += update
-            if (trueClassProb <= 0.95)
-              badErrs += 1
-          }
+  test ("value not readable in tasks") {
+    import SetAccum._
+    val maxI = 1000
+    for (nThreads <- List(1, 10)) { //test single & multi-threaded
+      val sc = new SparkContext("local[" + nThreads + "]", "test")
+      val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
+      val d = sc.parallelize(1 to maxI)
+      val thrown = evaluating {
+        d.foreach {
+          x => acc.value += x
         }
-        println("Iteration " + itr + " had badErrs = " + badErrs.value)
-        weights += weightDelta.value
-        println(weights)
-        //TODO I should check the number of bad errors here, but for some reason spark tries to serialize the assertion ...
-//        val assertVal = badErrs.value
-//        assert (assertVal < 100)
-      }
+      } should produce [SparkException]
+      println(thrown)
     }
   }
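For reference, here is a minimal sketch (not part of the patch) of what the change means for user code, assuming the old-style spark.SparkContext local-mode API used in the tests above; the object name and the "local[2]" master string are illustrative only:

import spark.SparkContext

// Sketch: tasks may still add to an accumulator, but only the driver may read it.
object AccumulatorValueSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "accumulator-demo")
    val acc = sc.accumulator(0)

    // Adding from inside a task is still allowed; each worker updates its
    // deserialized copy, and Accumulators.values now ships accum.localValue
    // back to the master.
    sc.parallelize(1 to 100).foreach(x => acc += x)

    // Reading acc.value *inside* a task would now throw
    // UnsupportedOperationException, e.g.:
    //   sc.parallelize(1 to 100).foreach(x => println(acc.value))

    // On the driver (deserialized == false), value still returns the merged result.
    println(acc.value)  // 5050

    sc.stop()
  }
}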