From 6d7f907e73e9702c0dbd0e41e4a52022c0b81d3d Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Tue, 11 Sep 2012 16:00:06 -0700
Subject: [PATCH] Manually merge pull request #175 by Imran Rashid

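This adds a GrowableAccumulableParam to Accumulators.scala and an
accumulableCollection helper to SparkContext, so that any standard mutable
collection implementing Growable and TraversableOnce (mutable Set, Map,
ArrayBuffer, ...) can be used directly as an accumulable shared variable,
with the value staying correctly typed. A "collection accumulators" test is
added to AccumulatorSuite.

A rough usage sketch (the variable and application names below are
illustrative; the pattern follows the new test):

    import scala.collection.mutable
    import spark.SparkContext

    val sc = new SparkContext("local", "collection-accumulator-example")
    val seen = sc.accumulableCollection(mutable.HashSet[Int]())
    sc.parallelize(1 to 100).foreach(i => seen += i)
    // seen.value is a mutable.HashSet[Int] on the driver -- no casts needed
    assert(seen.value.size == 100)
    sc.stop()
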
---
 core/src/main/scala/spark/Accumulators.scala  | 24 +++++++++++++++++
 core/src/main/scala/spark/SparkContext.scala  | 11 ++++++++
 .../test/scala/spark/AccumulatorSuite.scala   | 27 +++++++++++++++++--
 3 files changed, 60 insertions(+), 2 deletions(-)

diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index d764ffc29d..c157cc8feb 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -3,6 +3,7 @@ package spark
 import java.io._
 
 import scala.collection.mutable.Map
+import scala.collection.generic.Growable
 
 /**
  * A datatype that can be accumulated, i.e. has a commutative and associative +.
@@ -92,6 +93,29 @@ trait AccumulableParam[R, T] extends Serializable {
   def zero(initialValue: R): R
 }
 
+class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
+  extends AccumulableParam[R,T] {
+
+  def addAccumulator(growable: R, elem: T) : R = {
+    growable += elem
+    growable
+  }
+
+  def addInPlace(t1: R, t2: R) : R = {
+    t1 ++= t2
+    t1
+  }
+
+  def zero(initialValue: R): R = {
+    // We need to clone initialValue, but it's hard to specify that R should also be Cloneable.
+    // Instead we'll serialize it to a buffer and load it back.
+    val ser = (new spark.JavaSerializer).newInstance()
+    val copy = ser.deserialize[R](ser.serialize(initialValue))
+    copy.clear()   // Remove any elements copied over from initialValue
+    copy
+  }
+}
+
 /**
  * A simpler value of [[spark.Accumulable]] where the result type being accumulated is the same
  * as the types of elements being merged.
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 5d0f2950d6..0dec44979f 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -7,6 +7,7 @@ import akka.actor.Actor
 import akka.actor.Actor._
 
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.generic.Growable
 
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.conf.Configuration
@@ -307,6 +308,16 @@ class SparkContext(
   def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
     new Accumulable(initialValue, param)
 
+  /**
+   * Create an [[spark.Accumulable]] from a "mutable collection" type.
+   *
+   * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
+   * standard mutable collections. So you can use this with mutable Map, Set, etc.
+   */
+  def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R): Accumulable[R, T] = {
+    val param = new GrowableAccumulableParam[R,T]
+    new Accumulable(initialValue, param)
+  }
 
   // Keep around a weak hash map of values to Cached versions?
   def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala
index d55969c261..71df5941e5 100644
--- a/core/src/test/scala/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/spark/AccumulatorSuite.scala
@@ -56,7 +56,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
     }
   }
 
-
   implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] {
     def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = {
       t1 ++= t2
@@ -71,7 +70,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
     }
   }
 
-
   test ("value not readable in tasks") {
     import SetAccum._
     val maxI = 1000
@@ -89,4 +87,29 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
     }
   }
 
+  test ("collection accumulators") {
+    val maxI = 1000
+    for (nThreads <- List(1, 10)) {
+      // test single & multi-threaded
+      val sc = new SparkContext("local[" + nThreads + "]", "test")
+      val setAcc = sc.accumulableCollection(mutable.HashSet[Int]())
+      val bufferAcc = sc.accumulableCollection(mutable.ArrayBuffer[Int]())
+      val mapAcc = sc.accumulableCollection(mutable.HashMap[Int,String]())
+      val d = sc.parallelize((1 to maxI) ++ (1 to maxI))
+      d.foreach {
+        x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)}
+      }
+
+      // Note that this is typed correctly -- no casts necessary
+      setAcc.value.size should be (maxI)
+      bufferAcc.value.size should be (2 * maxI)
+      mapAcc.value.size should be (maxI)
+      for (i <- 1 to maxI) {
+        setAcc.value should contain(i)
+        bufferAcc.value should contain(i)
+        mapAcc.value should contain (i -> i.toString)
+      }
+      sc.stop()
+    }
+  }
 }
-- 
GitLab