Commit f7149c5e authored by Imran Rashid, committed by Matei Zaharia

tasks cannot access value of accumulator

parent 244cbbe3
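
In short: an accumulator's value can now only be read on the master. The getter throws UnsupportedOperationException when called inside a task, and Spark-internal code that needs the locally accumulated partial result uses the new private[spark] localValue instead. A minimal sketch of the resulting behavior, assuming an already-running local SparkContext named sc (not part of this diff):

    val acc = sc.accumulator(0)
    sc.parallelize(1 to 10).foreach { i =>
      acc += i         // adding to the accumulator from inside a task is still allowed
      // acc.value     // would now throw UnsupportedOperationException inside a task
    }
    println(acc.value) // reading the merged value on the master still works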
@@ -11,7 +11,7 @@ class Accumulable[T,R] (
   val id = Accumulators.newId
   @transient
-  var value_ = initialValue // Current value on master
+  private var value_ = initialValue // Current value on master
   val zero = param.zero(initialValue) // Zero value to be passed to workers
   var deserialized = false
@@ -30,7 +30,13 @@ class Accumulable[T,R] (
    * @param term the other Accumulable that will get merged with this
    */
   def ++= (term: T) { value_ = param.addInPlace(value_, term)}
-  def value = this.value_
+  def value = {
+    if (!deserialized) value_
+    else throw new UnsupportedOperationException("Can't use read value in task")
+  }
+
+  private[spark] def localValue = value_
+
   def value_= (t: T) {
     if (!deserialized) value_ = t
     else throw new UnsupportedOperationException("Can't use value_= in task")
@@ -126,7 +132,7 @@ private object Accumulators {
   def values: Map[Long, Any] = synchronized {
     val ret = Map[Long, Any]()
     for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) {
-      ret(id) = accum.value
+      ret(id) = accum.localValue
     }
     return ret
   }
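
The new guard leans on the @transient / deserialized pair shown in the first hunk: the copy constructed on the master keeps deserialized == false, while the copies shipped to workers are rebuilt by Java serialization and flip the flag, so their value getter refuses reads. A rough sketch of that pattern, illustrative only and not the actual Spark code:

    // Illustrative deserialization hook: any copy rebuilt from the wire is a task-side copy.
    private def readObject(in: java.io.ObjectInputStream) {
      in.defaultReadObject()
      deserialized = true
    }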
@@ -63,60 +63,19 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers {
   }
-  test ("value readable in tasks") {
-    import spark.util.Vector
-    //stochastic gradient descent with weights stored in accumulator -- should be able to read value as we go
-    //really easy data
-    val N = 10000 // Number of data points
-    val D = 10 // Numer of dimensions
-    val R = 0.7 // Scaling factor
-    val ITERATIONS = 5
-    val rand = new Random(42)
-    case class DataPoint(x: Vector, y: Double)
-    def generateData = {
-      def generatePoint(i: Int) = {
-        val y = if(i % 2 == 0) -1 else 1
-        val goodX = Vector(D, _ => 0.0001 * rand.nextGaussian() + y)
-        val noiseX = Vector(D, _ => rand.nextGaussian())
-        val x = Vector((goodX.elements.toSeq ++ noiseX.elements.toSeq): _*)
-        DataPoint(x, y)
-      }
-      Array.tabulate(N)(generatePoint)
-    }
-    val data = generateData
-    for (nThreads <- List(1, 10)) {
-      //test single & multi-threaded
-      val sc = new SparkContext("local[" + nThreads + "]", "test")
-      val weights = Vector.zeros(2*D)
-      val weightDelta = sc.accumulator(Vector.zeros(2 * D))
-      for (itr <- 1 to ITERATIONS) {
-        val eta = 0.1 / itr
-        val badErrs = sc.accumulator(0)
-        sc.parallelize(data).foreach {
-          p => {
-            //XXX Note the call to .value here. That is required for this to be an online gradient descent
-            // instead of a batch version. Should it change to .localValue, and should .value throw an error
-            // if you try to do this??
-            val prod = weightDelta.value.plusDot(weights, p.x)
-            val trueClassProb = (1 / (1 + exp(-p.y * prod))) // works b/c p(-z) = 1 - p(z) (where p is the logistic function)
-            val update = p.x * trueClassProb * p.y * eta
-            //we could also include a momentum term here if our weightDelta accumulator saved a momentum
-            weightDelta.value += update
-            if (trueClassProb <= 0.95)
-              badErrs += 1
-          }
-        }
-        println("Iteration " + itr + " had badErrs = " + badErrs.value)
-        weights += weightDelta.value
-        println(weights)
-        //TODO I should check the number of bad errors here, but for some reason spark tries to serialize the assertion ...
-        // val assertVal = badErrs.value
-        // assert (assertVal < 100)
-      }
-    }
-  }
+  test ("value not readable in tasks") {
+    import SetAccum._
+    val maxI = 1000
+    for (nThreads <- List(1, 10)) { //test single & multi-threaded
+      val sc = new SparkContext("local[" + nThreads + "]", "test")
+      val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
+      val d = sc.parallelize(1 to maxI)
+      val thrown = evaluating {
+        d.foreach {
+          x => acc.value += x
+        }
+      } should produce [SparkException]
+      println(thrown)
+    }
+  }
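
A note on the assertion in the new test: the UnsupportedOperationException is thrown inside a task, so the master only sees the job fail, and that failure is reported as a SparkException, which is why the matcher expects SparkException rather than the underlying exception. A hedged equivalent of the same check without the ScalaTest matcher, reusing sc and acc as defined in the test above:

    try {
      sc.parallelize(1 to maxI).foreach { x => println(acc.value) } // .value read inside a task
    } catch {
      case e: SparkException => println("job failed as expected: " + e)
    }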