Commit f7149c5e authored by Imran Rashid, committed by Matei Zaharia

tasks cannot access value of accumulator

parent 244cbbe3
@@ -11,7 +11,7 @@ class Accumulable[T,R] (
   val id = Accumulators.newId
   @transient
-  var value_ = initialValue // Current value on master
+  private var value_ = initialValue // Current value on master
   val zero = param.zero(initialValue) // Zero value to be passed to workers
   var deserialized = false

@@ -30,7 +30,13 @@ class Accumulable[T,R] (
    * @param term the other Accumulable that will get merged with this
    */
   def ++= (term: T) { value_ = param.addInPlace(value_, term)}
-  def value = this.value_
+  def value = {
+    if (!deserialized) value_
+    else throw new UnsupportedOperationException("Can't use read value in task")
+  }
+
+  private[spark] def localValue = value_
+
   def value_= (t: T) {
     if (!deserialized) value_ = t
     else throw new UnsupportedOperationException("Can't use value_= in task")

@@ -126,7 +132,7 @@ private object Accumulators {
   def values: Map[Long, Any] = synchronized {
     val ret = Map[Long, Any]()
     for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) {
-      ret(id) = accum.value
+      ret(id) = accum.localValue
     }
     return ret
   }
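
The change above makes value_ private and splits reads in two: the public value now throws once the accumulator has been deserialized into a task, while the new private[spark] localValue is what the framework itself reads in Accumulators.values when collecting per-task partial results. Below is a minimal, self-contained sketch of that mechanism using plain Java serialization; SketchAccumulable and roundTrip are illustrative names made up for this note, not part of Spark.

import java.io._

// Toy accumulator: the copy rebuilt by deserialization (the "task-side" copy)
// refuses to expose its partial value, but still lets the merge path read it.
class SketchAccumulable(initial: Int) extends Serializable {
  @transient private var value_ = initial
  @transient private var deserialized = false

  def +=(term: Int): Unit = { value_ += term }

  def value: Int =
    if (!deserialized) value_
    else throw new UnsupportedOperationException("Can't read value in task")

  def localValue: Int = value_   // what the framework would use when merging

  // On deserialization, reset to the initial value and mark this as a task-side copy.
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    value_ = initial
    deserialized = true
  }
}

object SketchAccumulable {
  private def roundTrip(a: SketchAccumulable): SketchAccumulable = {
    val buf = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buf)
    out.writeObject(a)
    out.close()
    new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
      .readObject().asInstanceOf[SketchAccumulable]
  }

  def main(args: Array[String]): Unit = {
    val driverCopy = new SketchAccumulable(0)
    val taskCopy = roundTrip(driverCopy)  // simulates shipping the accumulator to a task
    taskCopy += 5                         // tasks may still add
    println(taskCopy.localValue)          // 5 -- the merge path still sees the partial sum
    println(driverCopy.value)             // 0 -- reading on the "driver" side is fine
    try { taskCopy.value } catch {        // reading inside a "task" is now rejected
      case e: UnsupportedOperationException => println("rejected: " + e.getMessage)
    }
  }
}
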
@@ -63,60 +63,19 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers {
   }

-  test ("value readable in tasks") {
-    import spark.util.Vector
-    //stochastic gradient descent with weights stored in accumulator -- should be able to read value as we go
-
-    //really easy data
-    val N = 10000  // Number of data points
-    val D = 10   // Numer of dimensions
-    val R = 0.7  // Scaling factor
-    val ITERATIONS = 5
-    val rand = new Random(42)
-
-    case class DataPoint(x: Vector, y: Double)
-
-    def generateData = {
-      def generatePoint(i: Int) = {
-        val y = if(i % 2 == 0) -1 else 1
-        val goodX = Vector(D, _ => 0.0001 * rand.nextGaussian() + y)
-        val noiseX = Vector(D, _ => rand.nextGaussian())
-        val x = Vector((goodX.elements.toSeq ++ noiseX.elements.toSeq): _*)
-        DataPoint(x, y)
-      }
-      Array.tabulate(N)(generatePoint)
-    }
-
-    val data = generateData
-    for (nThreads <- List(1, 10)) {
-      //test single & multi-threaded
-      val sc = new SparkContext("local[" + nThreads + "]", "test")
-      val weights = Vector.zeros(2*D)
-      val weightDelta = sc.accumulator(Vector.zeros(2 * D))
-      for (itr <- 1 to ITERATIONS) {
-        val eta = 0.1 / itr
-        val badErrs = sc.accumulator(0)
-        sc.parallelize(data).foreach {
-          p => {
-            //XXX Note the call to .value here. That is required for this to be an online gradient descent
-            // instead of a batch version. Should it change to .localValue, and should .value throw an error
-            // if you try to do this??
-            val prod = weightDelta.value.plusDot(weights, p.x)
-            val trueClassProb = (1 / (1 + exp(-p.y * prod))) // works b/c p(-z) = 1 - p(z) (where p is the logistic function)
-            val update = p.x * trueClassProb * p.y * eta
-            //we could also include a momentum term here if our weightDelta accumulator saved a momentum
-            weightDelta.value += update
-            if (trueClassProb <= 0.95)
-              badErrs += 1
-          }
-        }
-        println("Iteration " + itr + " had badErrs = " + badErrs.value)
-        weights += weightDelta.value
-        println(weights)
-        //TODO I should check the number of bad errors here, but for some reason spark tries to serialize the assertion ...
-        //  val assertVal = badErrs.value
-        //  assert (assertVal < 100)
-      }
-    }
-  }
+  test ("value not readable in tasks") {
+    import SetAccum._
+    val maxI = 1000
+    for (nThreads <- List(1, 10)) { //test single & multi-threaded
+      val sc = new SparkContext("local[" + nThreads + "]", "test")
+      val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
+      val d = sc.parallelize(1 to maxI)
+      val thrown = evaluating {
+        d.foreach {
+          x => acc.value += x
+        }
+      } should produce [SparkException]
+      println(thrown)
+    }
+  }
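
For contrast with the failing read above, here is a hypothetical companion test -- not part of this commit -- sketching the pattern the suite still supports after the change: tasks only add to the accumulable with +=, and the driver reads the merged value once the action has finished. It assumes the same suite context as the new test (the SetAccum implicit, the scala.collection.mutable import, and Accumulable's element-adding += method).

  test ("tasks add, driver reads") {
    import SetAccum._
    for (nThreads <- List(1, 10)) { //test single & multi-threaded
      val sc = new SparkContext("local[" + nThreads + "]", "test")
      val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
      sc.parallelize(1 to 100).foreach { x => acc += x }   // task side: add only
      acc.value.size should be (100)                       // driver side: read the merged result
      sc.stop()
    }
  }
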