diff --git a/graph/src/main/scala/org/apache/spark/graph/VertexSetRDD.scala b/graph/src/main/scala/org/apache/spark/graph/VertexSetRDD.scala
index e8b8bb32280f394e3e240fe701675529b7876ebc..3bedf89c42ee1ee6781be745ee2d85bcad718c71 100644
--- a/graph/src/main/scala/org/apache/spark/graph/VertexSetRDD.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/VertexSetRDD.scala
@@ -22,7 +22,7 @@ import org.apache.spark.SparkContext._
 import org.apache.spark.rdd._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveKeyOpenHashMap}
-
+import org.apache.spark.graph.impl.AggregationMsg
 
 /**
  * The `VertexSetIndex` maintains the per-partition mapping from
@@ -659,6 +659,43 @@ object VertexSetRDD {
     apply(rdd,index, (v:V) => v, reduceFunc, reduceFunc)
 
 
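+  /**
+   * Construct a vertex set from an RDD of aggregation messages, reusing an
+   * existing index. Messages destined for the same vertex are combined with
+   * `reduceFunc`, which is assumed to be commutative and associative. The
+   * input RDD must already be partitioned by the index's partitioner.
+   */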
+  def aggregate[V: ClassManifest](
+      rdd: RDD[AggregationMsg[V]],
+      index: VertexSetIndex,
+      reduceFunc: (V, V) => V): VertexSetRDD[V] = {
+
+    val cReduceFunc = index.rdd.context.clean(reduceFunc)
+    assert(rdd.partitioner == index.rdd.partitioner)
+    // Use the index to build the new values table
+    val values: RDD[ (Array[V], BitSet) ] = index.rdd.zipPartitions(rdd)( (indexIter, tblIter) => {
+      // There is only one map
+      val index = indexIter.next()
+      assert(!indexIter.hasNext)
+      val values = new Array[V](index.capacity)
+      val bs = new BitSet(index.capacity)
+      for (msg <- tblIter) {
+        // Get the location of the key in the index
+        val pos = index.getPos(msg.vid)
+        if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
+          throw new SparkException("Error: Trying to bind an external index " +
+            "to an RDD which contains keys that are not in the index.")
+        } else {
+          // Get the actual index
+          val ind = pos & OpenHashSet.POSITION_MASK
+          // If this value has already been seen then merge
+          if (bs.get(ind)) {
+            values(ind) = cReduceFunc(values(ind), msg.data)
+          } else { // otherwise just store the new value
+            bs.set(ind)
+            values(ind) = msg.data
+          }
+        }
+      }
+      Iterator((values, bs))
+    })
+    new VertexSetRDD(index, values)
+  }
+
+
   /**
    * Construct a vertex set from an RDD using an existing index and a
    * user defined `combiner` to merge duplicate vertices.
@@ -675,11 +712,11 @@ object VertexSetRDD {
    *
    */
   def apply[V: ClassManifest, C: ClassManifest](
-    rdd: RDD[(Vid,V)],
-    index: VertexSetIndex,
-    createCombiner: V => C,
-    mergeValue: (C, V) => C,
-    mergeCombiners: (C, C) => C): VertexSetRDD[C] = {
+      rdd: RDD[(Vid,V)],
+      index: VertexSetIndex,
+      createCombiner: V => C,
+      mergeValue: (C, V) => C,
+      mergeCombiners: (C, C) => C): VertexSetRDD[C] = {
     val cCreateCombiner = index.rdd.context.clean(createCombiner)
     val cMergeValue = index.rdd.context.clean(mergeValue)
     val cMergeCombiners = index.rdd.context.clean(mergeCombiners)
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index 9ce06eb9e8b2cb481857f53a838bbd6eb9ddaada..0d7546b57594cdbfe27c4dd2ab95e5e9c15569f6 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -5,14 +5,13 @@ import scala.collection.JavaConversions._
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-
 import org.apache.spark.SparkContext._
 import org.apache.spark.HashPartitioner
 import org.apache.spark.util.ClosureCleaner
 
 import org.apache.spark.graph._
 import org.apache.spark.graph.impl.GraphImpl._
-import org.apache.spark.graph.impl.MessageToPartitionRDDFunctions._
+import org.apache.spark.graph.impl.MsgRDDFunctions._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveKeyOpenHashMap}
@@ -73,8 +72,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
 
   def this() = this(null, null, null, null)
 
-
-
   /**
    * (localVidMap: VertexSetRDD[Pid, VertexIdToIndexMap]) is a version of the
    * vertex data after it is replicated. Within each partition, it holds a map
@@ -87,22 +84,18 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
   @transient val vTableReplicatedValues: RDD[(Pid, Array[VD])] =
     createVTableReplicated(vTable, vid2pid, localVidMap)
 
-
   /** Return a RDD of vertices. */
   @transient override val vertices = vTable
 
-
   /** Return a RDD of edges. */
   @transient override val edges: RDD[Edge[ED]] = {
     eTable.mapPartitions( iter => iter.next()._2.iterator , true )
   }
 
-
   /** Return a RDD that brings edges with its source and destination vertices together. */
   @transient override val triplets: RDD[EdgeTriplet[VD, ED]] =
     makeTriplets(localVidMap, vTableReplicatedValues, eTable)
 
-
   override def persist(newLevel: StorageLevel): Graph[VD, ED] = {
     eTable.persist(newLevel)
     vid2pid.persist(newLevel)
@@ -129,7 +122,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
       "Min Load" -> minLoad, "Max Load" -> maxLoad)
   }
 
-
   /**
    * Display the lineage information for this graph.
    */
@@ -187,14 +179,12 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     println(visited)
   } // end of print lineage
 
-
   override def reverse: Graph[VD, ED] = {
     val newEtable = eTable.mapPartitions( _.map{ case (pid, epart) => (pid, epart.reverse) },
       preservesPartitioning = true)
     new GraphImpl(vTable, vid2pid, localVidMap, newEtable)
   }
 
-
   override def mapVertices[VD2: ClassManifest](f: (Vid, VD) => VD2): Graph[VD2, ED] = {
     val newVTable = vTable.mapValuesWithKeys((vid, data) => f(vid, data))
     new GraphImpl(newVTable, vid2pid, localVidMap, eTable)
@@ -206,11 +196,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     new GraphImpl(vTable, vid2pid, localVidMap, newETable)
   }
 
-
   override def mapTriplets[ED2: ClassManifest](f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] =
     GraphImpl.mapTriplets(this, f)
 
-
   override def subgraph(epred: EdgeTriplet[VD,ED] => Boolean = (x => true),
     vpred: (Vid, VD) => Boolean = ((a,b) => true) ): Graph[VD, ED] = {
 
@@ -250,7 +238,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     new GraphImpl(newVTable, newVid2Pid, localVidMap, newETable)
   }
 
-
   override def groupEdgeTriplets[ED2: ClassManifest](
     f: Iterator[EdgeTriplet[VD,ED]] => ED2 ): Graph[VD,ED2] = {
       val newEdges: RDD[Edge[ED2]] = triplets.mapPartitions { partIter =>
@@ -275,7 +262,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
       new GraphImpl(vTable, vid2pid, localVidMap, newETable)
   }
 
-
   override def groupEdges[ED2: ClassManifest](f: Iterator[Edge[ED]] => ED2 ):
     Graph[VD,ED2] = {
 
@@ -293,8 +279,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
       new GraphImpl(vTable, vid2pid, localVidMap, newETable)
   }
 
-
-
   //////////////////////////////////////////////////////////////////////////////////////////////////
   // Lower level transformation methods
   //////////////////////////////////////////////////////////////////////////////////////////////////
@@ -305,7 +289,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     : VertexSetRDD[A] =
     GraphImpl.mapReduceTriplets(this, mapFunc, reduceFunc)
 
-
   override def outerJoinVertices[U: ClassManifest, VD2: ClassManifest]
     (updates: RDD[(Vid, U)])(updateF: (Vid, VD, Option[U]) => VD2)
     : Graph[VD2, ED] = {
@@ -313,15 +296,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     val newVTable = vTable.leftJoin(updates)(updateF)
     new GraphImpl(newVTable, vid2pid, localVidMap, eTable)
   }
-
-
 } // end of class GraphImpl
 
 
-
-
-
-
 object GraphImpl {
 
   def apply[VD: ClassManifest, ED: ClassManifest](
@@ -331,7 +308,6 @@ object GraphImpl {
     apply(vertices, edges, defaultVertexAttr, (a:VD, b:VD) => a)
   }
 
-
   def apply[VD: ClassManifest, ED: ClassManifest](
     vertices: RDD[(Vid, VD)],
     edges: RDD[Edge[ED]],
@@ -357,7 +333,6 @@ object GraphImpl {
     new GraphImpl(vtable, vid2pid, localVidMap, etable)
   }
 
-
   /**
    * Create the edge table RDD, which is much more efficient for Java heap storage than the
    * normal edges data structure (RDD[(Vid, Vid, ED)]).
@@ -379,7 +354,7 @@ object GraphImpl {
       //val part: Pid = canonicalEdgePartitionFunction2D(e.srcId, e.dstId, numPartitions, ceilSqrt)
 
       // Should we be using 3-tuple or an optimized class
-      MessageToPartition(part, (e.srcId, e.dstId, e.attr))
+      new MessageToPartition(part, (e.srcId, e.dstId, e.attr))
     }
     .partitionBy(new HashPartitioner(numPartitions))
     .mapPartitionsWithIndex( (pid, iter) => {
@@ -393,7 +368,6 @@ object GraphImpl {
     }, preservesPartitioning = true).cache()
   }
 
-
   protected def createVid2Pid[ED: ClassManifest](
     eTable: RDD[(Pid, EdgePartition[ED])],
     vTableIndex: VertexSetIndex): VertexSetRDD[Array[Pid]] = {
@@ -410,7 +384,6 @@ object GraphImpl {
       .mapValues(a => a.toArray).cache()
   }
 
-
   protected def createLocalVidMap[ED: ClassManifest](eTable: RDD[(Pid, EdgePartition[ED])]):
     RDD[(Pid, VertexIdToIndexMap)] = {
     eTable.mapPartitions( _.map{ case (pid, epart) =>
@@ -423,7 +396,6 @@ object GraphImpl {
     }, preservesPartitioning = true).cache()
   }
 
-
   protected def createVTableReplicated[VD: ClassManifest](
       vTable: VertexSetRDD[VD],
       vid2pid: VertexSetRDD[Array[Pid]],
@@ -432,7 +404,10 @@ object GraphImpl {
     // Join vid2pid and vTable, generate a shuffle dependency on the joined
     // result, and get the shuffle id so we can use it on the slave.
     val msgsByPartition = vTable.zipJoinFlatMap(vid2pid) { (vid, vdata, pids) =>
-      pids.iterator.map { pid => MessageToPartition(pid, (vid, vdata)) }
+      // TODO(rxin): reuse VertexBroadcastMessage
+      pids.iterator.map { pid =>
+        new VertexBroadcastMsg[VD](pid, vid, vdata)
+      }
     }.partitionBy(replicationMap.partitioner.get).cache()
 
     replicationMap.zipPartitions(msgsByPartition){
@@ -442,8 +417,8 @@ object GraphImpl {
       // Populate the vertex array using the vidToIndex map
       val vertexArray = new Array[VD](vidToIndex.capacity)
       for (msg <- msgsIter) {
-        val ind = vidToIndex.getPos(msg.data._1) & OpenHashSet.POSITION_MASK
-        vertexArray(ind) = msg.data._2
+        val ind = vidToIndex.getPos(msg.vid) & OpenHashSet.POSITION_MASK
+        vertexArray(ind) = msg.data
       }
       Iterator((pid, vertexArray))
     }.cache()
@@ -451,7 +426,6 @@ object GraphImpl {
     // @todo assert edge table has partitioner
   }
 
-
   def makeTriplets[VD: ClassManifest, ED: ClassManifest](
     localVidMap: RDD[(Pid, VertexIdToIndexMap)],
     vTableReplicatedValues: RDD[(Pid, Array[VD]) ],
@@ -465,7 +439,6 @@ object GraphImpl {
     }
   }
 
-
   def mapTriplets[VD: ClassManifest, ED: ClassManifest, ED2: ClassManifest](
     g: GraphImpl[VD, ED],
     f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
@@ -487,7 +460,6 @@ object GraphImpl {
     new GraphImpl(g.vTable, g.vid2pid, g.localVidMap, newETable)
   }
 
-
   def mapReduceTriplets[VD: ClassManifest, ED: ClassManifest, A: ClassManifest](
     g: GraphImpl[VD, ED],
     mapFunc: EdgeTriplet[VD, ED] => Array[(Vid, A)],
@@ -499,33 +471,35 @@ object GraphImpl {
     // Map and preaggregate
     val preAgg = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues){
       (edgePartitionIter, vidToIndexIter, vertexArrayIter) =>
-      val (pid, edgePartition) = edgePartitionIter.next()
+      val (_, edgePartition) = edgePartitionIter.next()
       val (_, vidToIndex) = vidToIndexIter.next()
       val (_, vertexArray) = vertexArrayIter.next()
       assert(!edgePartitionIter.hasNext)
       assert(!vidToIndexIter.hasNext)
       assert(!vertexArrayIter.hasNext)
       assert(vidToIndex.capacity == vertexArray.size)
+      // Reuse the vidToIndex map to run aggregation.
       val vmap = new PrimitiveKeyOpenHashMap[Vid, VD](vidToIndex, vertexArray)
-      // We can reuse the vidToIndex map for aggregation here as well.
-      /** @todo Since this has the downside of not allowing "messages" to arbitrary
-       * vertices we should consider just using a fresh map.
-       */
+      // TODO(jegonzal): This doesn't allow users to send messages to arbitrary vertices.
       val msgArray = new Array[A](vertexArray.size)
       val msgBS = new BitSet(vertexArray.size)
       // Iterate over the partition
       val et = new EdgeTriplet[VD, ED]
-      edgePartition.foreach{e =>
+
+      edgePartition.foreach { e =>
         et.set(e)
         et.srcAttr = vmap(e.srcId)
         et.dstAttr = vmap(e.dstId)
-        mapFunc(et).foreach{ case (vid, msg) =>
+        // TODO(rxin): rewrite the foreach using a simple while loop to speed things up.
+        // Also given we are only allowing zero, one, or two messages, we can completely unroll
+        // the for loop.
+        mapFunc(et).foreach { case (vid, msg) =>
           // verify that the vid is valid
           assert(vid == et.srcId || vid == et.dstId)
           // Get the index of the key
           val ind = vidToIndex.getPos(vid) & OpenHashSet.POSITION_MASK
           // Populate the aggregator map
-          if(msgBS.get(ind)) {
+          if (msgBS.get(ind)) {
             msgArray(ind) = reduceFunc(msgArray(ind), msg)
           } else {
             msgArray(ind) = msg
@@ -534,20 +508,19 @@ object GraphImpl {
         }
       }
       // construct an iterator of tuples Iterator[(Vid, A)]
-      msgBS.iterator.map( ind => (vidToIndex.getValue(ind), msgArray(ind)) )
+      msgBS.iterator.map { ind =>
+        new AggregationMsg[A](vidToIndex.getValue(ind), msgArray(ind))
+      }
     }.partitionBy(g.vTable.index.rdd.partitioner.get)
     // do the final reduction reusing the index map
-    VertexSetRDD(preAgg, g.vTable.index, reduceFunc)
+    VertexSetRDD.aggregate(preAgg, g.vTable.index, reduceFunc)
   }
 
-
   protected def edgePartitionFunction1D(src: Vid, dst: Vid, numParts: Pid): Pid = {
     val mixingPrime: Vid = 1125899906842597L
     (math.abs(src) * mixingPrime).toInt % numParts
   }
 
-
-
   /**
    * This function implements a classic 2D-Partitioning of a sparse matrix.
    * Suppose we have a graph with 11 vertices that we want to partition
@@ -600,7 +573,6 @@ object GraphImpl {
     (col * ceilSqrtNumParts + row) % numParts
   }
 
-
   /**
    * Assign edges to an arbitrary machine corresponding to a
    * random vertex cut.
@@ -609,7 +581,6 @@ object GraphImpl {
     math.abs((src, dst).hashCode()) % numParts
   }
 
-
   /**
    * @todo This will only partition edges to the upper diagonal
    * of the 2D processor space.
@@ -626,4 +597,3 @@ object GraphImpl {
   }
 
 } // end of object GraphImpl
-
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala b/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala
index b7bbf257a4a5692163d5e53c07473ca35d3dab9b..3fc0b7c0f7588d3b2aff84868367edc7f45dcd39 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala
@@ -1,10 +1,35 @@
 package org.apache.spark.graph.impl
 
 import org.apache.spark.Partitioner
-import org.apache.spark.graph.Pid
+import org.apache.spark.graph.{Pid, Vid}
 import org.apache.spark.rdd.{ShuffledRDD, RDD}
 
 
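+/**
+ * A message used to replicate a vertex attribute to a specific edge partition.
+ * Implementing Product2 allows it to be shuffled as a (Pid, (Vid, T)) pair.
+ */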
+class VertexBroadcastMsg[@specialized(Int, Long, Double, Boolean) T](
+    @transient var partition: Pid,
+    var vid: Vid,
+    var data: T)
+  extends Product2[Pid, (Vid, T)] {
+
+  override def _1 = partition
+
+  override def _2 = (vid, data)
+
+  override def canEqual(that: Any): Boolean = that.isInstanceOf[VertexBroadcastMsg[_]]
+}
+
+
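+/**
+ * A message carrying a value to be aggregated at a target vertex.
+ * Implementing Product2 allows it to be shuffled as a (Vid, T) pair.
+ */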
+class AggregationMsg[@specialized(Int, Long, Double, Boolean) T](var vid: Vid, var data: T)
+  extends Product2[Vid, T] {
+
+  override def _1 = vid
+
+  override def _2 = data
+
+  override def canEqual(that: Any): Boolean = that.isInstanceOf[AggregationMsg[_]]
+}
+
+
 /**
  * A message used to send a specific value to a partition.
  * @param partition index of the target partition.
@@ -22,15 +47,38 @@ class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef
   override def canEqual(that: Any): Boolean = that.isInstanceOf[MessageToPartition[_]]
 }
 
-/**
- * Companion object for MessageToPartition.
- */
-object MessageToPartition {
-  def apply[T](partition: Pid, value: T) = new MessageToPartition(partition, value)
+
+class VertexBroadcastMsgRDDFunctions[T: ClassManifest](self: RDD[VertexBroadcastMsg[T]]) {
+  def partitionBy(partitioner: Partitioner): RDD[VertexBroadcastMsg[T]] = {
+    val rdd = new ShuffledRDD[Pid, (Vid, T), VertexBroadcastMsg[T]](self, partitioner)
+
+    // Set a custom serializer if the data is of int or double type.
+    if (classManifest[T] == ClassManifest.Int) {
+      rdd.setSerializer(classOf[IntVertexBroadcastMsgSerializer].getName)
+    } else if (classManifest[T] == ClassManifest.Double) {
+      rdd.setSerializer(classOf[DoubleVertexBroadcastMsgSerializer].getName)
+    }
+    rdd
+  }
 }
 
 
-class MessageToPartitionRDDFunctions[T: ClassManifest](self: RDD[MessageToPartition[T]]) {
+class AggregationMessageRDDFunctions[T: ClassManifest](self: RDD[AggregationMsg[T]]) {
+  def partitionBy(partitioner: Partitioner): RDD[AggregationMsg[T]] = {
+    val rdd = new ShuffledRDD[Vid, T, AggregationMsg[T]](self, partitioner)
+
+    // Set a custom serializer if the data is of int or double type.
+    if (classManifest[T] == ClassManifest.Int) {
+      rdd.setSerializer(classOf[IntAggMsgSerializer].getName)
+    } else if (classManifest[T] == ClassManifest.Double) {
+      rdd.setSerializer(classOf[DoubleAggMsgSerializer].getName)
+    }
+    rdd
+  }
+}
+
+
+class MsgRDDFunctions[T: ClassManifest](self: RDD[MessageToPartition[T]]) {
 
   /**
    * Return a copy of the RDD partitioned using the specified partitioner.
@@ -42,8 +90,16 @@ class MessageToPartitionRDDFunctions[T: ClassManifest](self: RDD[MessageToPartit
 }
 
 
-object MessageToPartitionRDDFunctions {
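+/**
+ * Implicit conversions that expose `partitionBy` on the message RDD types
+ * defined above.
+ */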
+object MsgRDDFunctions {
   implicit def rdd2PartitionRDDFunctions[T: ClassManifest](rdd: RDD[MessageToPartition[T]]) = {
-    new MessageToPartitionRDDFunctions(rdd)
+    new MsgRDDFunctions(rdd)
+  }
+
+  implicit def rdd2vertexMessageRDDFunctions[T: ClassManifest](rdd: RDD[VertexBroadcastMsg[T]]) = {
+    new VertexBroadcastMsgRDDFunctions(rdd)
+  }
+
+  implicit def rdd2aggMessageRDDFunctions[T: ClassManifest](rdd: RDD[AggregationMsg[T]]) = {
+    new AggregationMessageRDDFunctions(rdd)
   }
 }
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala b/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala
new file mode 100644
index 0000000000000000000000000000000000000000..8b4c0868b1a6dff011cab21938f60fc1632c6ece
--- /dev/null
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala
@@ -0,0 +1,169 @@
+package org.apache.spark.graph.impl
+
+import java.io.{InputStream, OutputStream}
+import java.nio.ByteBuffer
+
+import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance, Serializer}
+
+
+/** A special shuffle serializer for VertexBroadcastMsg[Int]. */
+class IntVertexBroadcastMsgSerializer extends Serializer {
+  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+      def writeObject[T](t: T) = {
+        val msg = t.asInstanceOf[VertexBroadcastMsg[Int]]
+        writeLong(msg.vid)
+        writeInt(msg.data)
+        this
+      }
+    }
+
+    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+      override def readObject[T](): T = {
+        new VertexBroadcastMsg[Int](0, readLong(), readInt()).asInstanceOf[T]
+      }
+    }
+  }
+}
+
+
+/** A special shuffle serializer for VertexBroadcastMsg[Double]. */
+class DoubleVertexBroadcastMsgSerializer extends Serializer {
+  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+      def writeObject[T](t: T) = {
+        val msg = t.asInstanceOf[VertexBroadcastMsg[Double]]
+        writeLong(msg.vid)
+        writeDouble(msg.data)
+        this
+      }
+    }
+
+    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+      def readObject[T](): T = {
+        new VertexBroadcastMsg[Double](0, readLong(), readDouble()).asInstanceOf[T]
+      }
+    }
+  }
+}
+
+
+/** A special shuffle serializer for AggregationMsg[Int]. */
+class IntAggMsgSerializer extends Serializer {
+  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+      def writeObject[T](t: T) = {
+        val msg = t.asInstanceOf[AggregationMsg[Int]]
+        writeLong(msg.vid)
+        writeInt(msg.data)
+        this
+      }
+    }
+
+    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+      override def readObject[T](): T = {
+        new AggregationMsg[Int](readLong(), readInt()).asInstanceOf[T]
+      }
+    }
+  }
+}
+
+
+/** A special shuffle serializer for AggregationMsg[Double]. */
+class DoubleAggMsgSerializer extends Serializer {
+  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+      def writeObject[T](t: T) = {
+        val msg = t.asInstanceOf[AggregationMsg[Double]]
+        writeLong(msg.vid)
+        writeDouble(msg.data)
+        this
+      }
+    }
+
+    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+      def readObject[T](): T = {
+        new AggregationMsg[Double](readLong(), readDouble()).asInstanceOf[T]
+      }
+    }
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Helper classes to shorten the implementation of those special serializers.
+////////////////////////////////////////////////////////////////////////////////
+
+sealed abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream {
+  // The implementation should override this one.
+  def writeObject[T](t: T): SerializationStream
+
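+  // Write primitives in big-endian byte order, most significant byte first.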
+  def writeInt(v: Int) {
+    s.write(v >> 24)
+    s.write(v >> 16)
+    s.write(v >> 8)
+    s.write(v)
+  }
+
+  def writeLong(v: Long) {
+    s.write((v >>> 56).toInt)
+    s.write((v >>> 48).toInt)
+    s.write((v >>> 40).toInt)
+    s.write((v >>> 32).toInt)
+    s.write((v >>> 24).toInt)
+    s.write((v >>> 16).toInt)
+    s.write((v >>> 8).toInt)
+    s.write(v.toInt)
+  }
+
+  def writeDouble(v: Double) {
+    writeLong(java.lang.Double.doubleToLongBits(v))
+  }
+
+  override def flush(): Unit = s.flush()
+
+  override def close(): Unit = s.close()
+}
+
+
+sealed abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream {
+  // The implementation should override this one.
+  def readObject[T](): T
+
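+  // Read primitives back in the same big-endian byte order used for writing.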
+  def readInt(): Int = {
+    (s.read() & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF)
+  }
+
+  def readLong(): Long = {
+    (s.read().toLong << 56) |
+      (s.read() & 0xFF).toLong << 48 |
+      (s.read() & 0xFF).toLong << 40 |
+      (s.read() & 0xFF).toLong << 32 |
+      (s.read() & 0xFF).toLong << 24 |
+      (s.read() & 0xFF) << 16 |
+      (s.read() & 0xFF) << 8 |
+      (s.read() & 0xFF)
+  }
+
+  def readDouble(): Double = java.lang.Double.longBitsToDouble(readLong())
+
+  override def close(): Unit = s.close()
+}
+
+
+sealed trait ShuffleSerializerInstance extends SerializerInstance {
+
+  override def serialize[T](t: T): ByteBuffer = throw new UnsupportedOperationException
+
+  override def deserialize[T](bytes: ByteBuffer): T = throw new UnsupportedOperationException
+
+  override def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T =
+    throw new UnsupportedOperationException
+
+  // The implementation should override the following two.
+  override def serializeStream(s: OutputStream): SerializationStream
+  override def deserializeStream(s: InputStream): DeserializationStream
+}
diff --git a/graphx-shell b/graphx-shell
new file mode 100755
index 0000000000000000000000000000000000000000..4dd6c68ace888d3996b0ee578057eb922507d38a
--- /dev/null
+++ b/graphx-shell
@@ -0,0 +1,124 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+#
+# Shell script for starting the Spark Shell REPL
+# Note that it will set MASTER to spark://${SPARK_MASTER_IP}:${SPARK_MASTER_PORT}
+# if those two env vars are set in spark-env.sh but MASTER is not.
+# Options:
+#    -c <cores>    Set the number of cores for REPL to use
+#
+
+# Enter posix mode for bash
+set -o posix
+
+
+# Update the banner logo
+export SPARK_BANNER_TEXT="Welcome to 
+        ______                 __   _  __ 
+       / ____/________ _____  / /_ | |/ / 
+      / / __/ ___/ __ \`/ __ \/ __ \|   / 
+     / /_/ / /  / /_/ / /_/ / / / /   |   
+     \____/_/   \__,_/ .___/_/ /_/_/|_|   
+                    /_/  Alpha Release               
+
+Powered by: 
+       ____              __
+      / __/__  ___ _____/ /__
+     _\ \/ _ \/ _ \`/ __/  '_/
+    /___/ .__/\_,_/_/ /_/\_\   
+       /_/ version 0.9.0
+
+Example:
+
+ scala> val graph = GraphLoader.textFile(sc, \"hdfs://links\")
+ scala> graph.numVertices
+ scala> graph.numEdges
+ scala> val pageRankGraph = Analytics.pagerank(graph, 10) // 10 iterations
+ scala> val maxPr = pageRankGraph.vertices.map{ case (vid, pr) => pr }.max
+ scala> println(maxPr)
+
+"
+
+export SPARK_SHELL_INIT_BLOCK="import org.apache.spark.graph._;"
+
+# Set the serializer to use Kryo for graphx objects
+SPARK_JAVA_OPTS+=" -Dspark.serializer=org.apache.spark.serializer.KryoSerializer "
+SPARK_JAVA_OPTS+="-Dspark.kryo.registrator=org.apache.spark.graph.GraphKryoRegistrator  "
+SPARK_JAVA_OPTS+="-Dspark.kryoserializer.buffer.mb=10 "
+
+
+
+FWDIR="`dirname $0`"
+
+for o in "$@"; do
+  if [ "$1" = "-c" -o "$1" = "--cores" ]; then
+    shift
+    if [ -n "$1" ]; then
+      OPTIONS="-Dspark.cores.max=$1"
+      shift
+    fi
+  fi
+done
+
+# Set MASTER from spark-env if possible
+if [ -z "$MASTER" ]; then
+  if [ -e "$FWDIR/conf/spark-env.sh" ]; then
+    . "$FWDIR/conf/spark-env.sh"
+  fi
+  if [[ "x" != "x$SPARK_MASTER_IP" && "y" != "y$SPARK_MASTER_PORT" ]]; then
+    MASTER="spark://${SPARK_MASTER_IP}:${SPARK_MASTER_PORT}"
+    export MASTER
+  fi
+fi
+
+# Copy restore-TTY-on-exit functions from Scala script so spark-shell exits properly even in
+# binary distribution of Spark where Scala is not installed
+exit_status=127
+saved_stty=""
+
+# restore stty settings (echo in particular)
+function restoreSttySettings() {
+  stty $saved_stty
+  saved_stty=""
+}
+
+function onExit() {
+  if [[ "$saved_stty" != "" ]]; then
+    restoreSttySettings
+  fi
+  exit $exit_status
+}
+
+# to reenable echo if we are interrupted before completing.
+trap onExit INT
+
+# save terminal settings
+saved_stty=$(stty -g 2>/dev/null)
+# clear on error so we don't later try to restore them
+if [[ $? -ne 0 ]]; then
+  saved_stty=""
+fi
+
+$FWDIR/spark-class $OPTIONS org.apache.spark.repl.Main "$@"
+
+# record the exit status lest it be overwritten:
+# then reenable echo and propagate the code.
+exit_status=$?
+onExit
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index 0ced284da68f50bc24a4305dd43668268f7f09a5..efdd90c47f7c84ce5225c64f9e20e5997fa64d83 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -45,7 +45,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
   def this(in0: BufferedReader, out: PrintWriter, master: String) = this(Some(in0), out, Some(master))
   def this(in0: BufferedReader, out: PrintWriter) = this(Some(in0), out, None)
   def this() = this(None, new PrintWriter(Console.out, true), None)
-  
+
   var in: InteractiveReader = _   // the input stream from which commands come
   var settings: Settings = _
   var intp: SparkIMain = _
@@ -56,16 +56,16 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
     Power[g.type](this, g)
   }
   */
-  
+
   // TODO
   // object opt extends AestheticSettings
-  // 
+  //
   @deprecated("Use `intp` instead.", "2.9.0")
   def interpreter = intp
-  
+
   @deprecated("Use `intp` instead.", "2.9.0")
   def interpreter_= (i: SparkIMain): Unit = intp = i
-  
+
   def history = in.history
 
   /** The context class loader at the time this object was created */
@@ -75,7 +75,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
   private val signallable =
     /*if (isReplDebug) Signallable("Dump repl state.")(dumpCommand())
     else*/ null
-    
+
   // classpath entries added via :cp
   var addedClasspath: String = ""
 
@@ -87,10 +87,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
 
   /** Record a command for replay should the user request a :replay */
   def addReplay(cmd: String) = replayCommandStack ::= cmd
-  
+
   /** Try to install sigint handler: ignore failure.  Signal handler
    *  will interrupt current line execution if any is in progress.
-   * 
+   *
    *  Attempting to protect the repl from accidental exit, we only honor
    *  a single ctrl-C if the current buffer is empty: otherwise we look
    *  for a second one within a short time.
@@ -124,7 +124,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
       Thread.currentThread.setContextClassLoader(originalClassLoader)
     }
   }
-  
+
   class SparkILoopInterpreter extends SparkIMain(settings, out) {
     override lazy val formatting = new Formatting {
       def prompt = SparkILoop.this.prompt
@@ -135,7 +135,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
           |// She's gone rogue, captain! Have to take her out!
           |// Calling Thread.stop on runaway %s with offending code:
           |// scala> %s""".stripMargin
-        
+
         echo(template.format(line.thread, line.code))
         // XXX no way to suppress the deprecation warning
         line.thread.stop()
@@ -151,7 +151,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
   def createInterpreter() {
     if (addedClasspath != "")
       settings.classpath append addedClasspath
-      
+
     intp = new SparkILoopInterpreter
     intp.setContextClassLoader()
     installSigIntHandler()
@@ -168,10 +168,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
   private def helpSummary() = {
     val usageWidth  = commands map (_.usageMsg.length) max
     val formatStr   = "%-" + usageWidth + "s %s %s"
-    
+
     echo("All commands can be abbreviated, e.g. :he instead of :help.")
     echo("Those marked with a * have more detailed help, e.g. :help imports.\n")
-    
+
     commands foreach { cmd =>
       val star = if (cmd.hasLongHelp) "*" else " "
       echo(formatStr.format(cmd.usageMsg, star, cmd.help))
@@ -182,7 +182,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
       case Nil  => echo(cmd + ": no such command.  Type :help for help.")
       case xs   => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?")
     }
-    Result(true, None)    
+    Result(true, None)
   }
   private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd)
   private def uniqueCommand(cmd: String): Option[LoopCommand] = {
@@ -193,31 +193,35 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
       case xs       => xs find (_.name == cmd)
     }
   }
-  
+
   /** Print a welcome message */
   def printWelcome() {
-    echo("""Welcome to
-      ____              __  
+    val prop = System.getenv("SPARK_BANNER_TEXT")
+    val bannerText =
+      if (prop != null) prop else
+        """Welcome to
+      ____              __
      / __/__  ___ _____/ /__
     _\ \/ _ \/ _ `/ __/  '_/
    /___/ .__/\_,_/_/ /_/\_\   version 0.9.0-SNAPSHOT
-      /_/                  
-""")
+      /_/
+        """
+    echo(bannerText)
     import Properties._
     val welcomeMsg = "Using Scala %s (%s, Java %s)".format(
-      versionString, javaVmName, javaVersion) 
+      versionString, javaVmName, javaVersion)
     echo(welcomeMsg)
   }
-  
+
   /** Show the history */
   lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") {
     override def usage = "[num]"
     def defaultLines = 20
-    
+
     def apply(line: String): Result = {
       if (history eq NoHistory)
         return "No history available."
-      
+
       val xs      = words(line)
       val current = history.index
       val count   = try xs.head.toInt catch { case _: Exception => defaultLines }
@@ -237,21 +241,21 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
     out print msg
     out.flush()
   }
-  
+
   /** Search the history */
   def searchHistory(_cmdline: String) {
     val cmdline = _cmdline.toLowerCase
     val offset  = history.index - history.size + 1
-    
+
     for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline)
       echo("%d %s".format(index + offset, line))
   }
-  
+
   private var currentPrompt = Properties.shellPromptString
   def setPrompt(prompt: String) = currentPrompt = prompt
   /** Prompt to print when awaiting input */
   def prompt = currentPrompt
-  
+
   import LoopCommand.{ cmd, nullary }
 
   /** Standard commands **/
@@ -273,7 +277,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
     nullary("silent", "disable/enable automatic printing of results", verbosity),
     cmd("type", "<expr>", "display the type of an expression without evaluating it", typeCommand)
   )
-  
+
   /** Power user commands */
   lazy val powerCommands: List[LoopCommand] = List(
     //nullary("dump", "displays a view of the interpreter's internal state", dumpCommand),
@@ -298,10 +302,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
       |An argument of clear will remove the wrapper if any is active.
       |Note that wrappers do not compose (a new one replaces the old
       |one) and also that the :phase command uses the same machinery,
-      |so setting :wrap will clear any :phase setting.       
+      |so setting :wrap will clear any :phase setting.
     """.stripMargin.trim)
   )
-  
+
   /*
   private def dumpCommand(): Result = {
     echo("" + power)
@@ -309,7 +313,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
     in.redrawLine()
   }
   */
-  
+
   private val typeTransforms = List(
     "scala.collection.immutable." -> "immutable.",
     "scala.collection.mutable."   -> "mutable.",
@@ -317,7 +321,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
     "java.lang."                  -> "jl.",
     "scala.runtime."              -> "runtime."
   )
-  
+
   private def importsCommand(line: String): Result = {
     val tokens    = words(line)
     val handlers  = intp.languageWildcardHandlers ++ intp.importHandlers
@@ -333,7 +337,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
         val implicitMsg    = if (imps.isEmpty) "" else imps.size + " are implicit"
         val foundMsg       = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "")
         val statsMsg       = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")")
-        
+
         intp.reporter.printMessage("%2d) %-30s %s%s".format(
           idx + 1,
           handler.importString,
@@ -342,12 +346,12 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
         ))
     }
   }
-  
+
   private def implicitsCommand(line: String): Result = {
     val intp = SparkILoop.this.intp
     import intp._
     import global.Symbol
-    
+
     def p(x: Any) = intp.reporter.printMessage("" + x)
 
     // If an argument is given, only show a source with that
@@ -360,14 +364,14 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
           else (args exists (source.name.toString contains _))
         }
     }
-    
+
     if (filtered.isEmpty)
       return "No implicits have been imported other than those in Predef."
-      
+
     filtered foreach {
       case (source, syms) =>
         p("/* " + syms.size + " implicit members imported from " + source.fullName + " */")
-        
+
         // This groups the members by where the symbol is defined
         val byOwner = syms groupBy (_.owner)
         val sortedOwners = byOwner.toList sortBy { case (owner, _) => intp.afterTyper(source.info.baseClasses indexOf owner) }
@@ -388,10 +392,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
 
               xss map (xs => xs sortBy (_.name.toString))
             }
-          
-            val ownerMessage = if (owner == source) " defined in " else " inherited from "            
+
+            val ownerMessage = if (owner == source) " defined in " else " inherited from "
             p("  /* " + members.size + ownerMessage + owner.fullName + " */")
-            
+
             memberGroups foreach { group =>
               group foreach (s => p("  " + intp.symbolDefString(s)))
               p("")
@@ -400,7 +404,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
         p("")
     }
   }
-  
+
   protected def newJavap() = new Javap(intp.classLoader, new SparkIMain.ReplStrippingWriter(intp)) {
     override def tryClass(path: String): Array[Byte] = {
       // Look for Foo first, then Foo$, but if Foo$ is given explicitly,
@@ -417,20 +421,20 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
   private lazy val javap =
     try newJavap()
     catch { case _: Exception => null }
-  
+
   private def typeCommand(line: String): Result = {
     intp.typeOfExpression(line) match {
       case Some(tp) => tp.toString
       case _        => "Failed to determine type."
     }
   }
-  
+
   private def javapCommand(line: String): Result = {
     if (javap == null)
       return ":javap unavailable on this platform."
     if (line == "")
       return ":javap [-lcsvp] [path1 path2 ...]"
-    
+
     javap(words(line)) foreach { res =>
       if (res.isError) return "Failed: " + res.value
       else res.show()
@@ -504,25 +508,25 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
     }
     else {
       val what = phased.parse(name)
-      if (what.isEmpty || !phased.set(what)) 
+      if (what.isEmpty || !phased.set(what))
         "'" + name + "' does not appear to represent a valid phase."
       else {
         intp.setExecutionWrapper(pathToPhaseWrapper)
         val activeMessage =
           if (what.toString.length == name.length) "" + what
           else "%s (%s)".format(what, name)
-        
+
         "Active phase is now: " + activeMessage
       }
     }
   }
   */
-  
+
   /** Available commands */
   def commands: List[LoopCommand] = standardCommands /* ++ (
     if (isReplPower) powerCommands else Nil
   )*/
-  
+
   val replayQuestionMessage =
     """|The repl compiler has crashed spectacularly. Shall I replay your
        |session? I can re-run all lines except the last one.
@@ -579,10 +583,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
   }
 
   /** interpret all lines from a specified file */
-  def interpretAllFrom(file: File) {    
+  def interpretAllFrom(file: File) {
     val oldIn = in
     val oldReplay = replayCommandStack
-    
+
     try file applyReader { reader =>
       in = SimpleReader(reader, out, false)
       echo("Loading " + file + "...")
@@ -604,26 +608,26 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
       echo("")
     }
   }
-  
+
   /** fork a shell and run a command */
   lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") {
     override def usage = "<command line>"
     def apply(line: String): Result = line match {
       case ""   => showUsage()
-      case _    => 
+      case _    =>
         val toRun = classOf[ProcessResult].getName + "(" + string2codeQuoted(line) + ")"
         intp interpret toRun
         ()
     }
   }
-  
+
   def withFile(filename: String)(action: File => Unit) {
     val f = File(filename)
-    
+
     if (f.exists) action(f)
     else echo("That file does not exist")
   }
-  
+
   def loadCommand(arg: String) = {
     var shouldReplay: Option[String] = None
     withFile(arg)(f => {
@@ -657,7 +661,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
     }
     else echo("The path '" + f + "' doesn't seem to exist.")
   }
-  
+
   def powerCmd(): Result = {
     if (isReplPower) "Already in power mode."
     else enablePowerMode()
@@ -667,13 +671,13 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
     //power.unleash()
     //echo(power.banner)
   }
-  
+
   def verbosity() = {
     val old = intp.printResults
     intp.printResults = !old
     echo("Switched " + (if (old) "off" else "on") + " result printing.")
   }
-  
+
   /** Run one command submitted by the user.  Two values are returned:
     * (1) whether to keep running, (2) the line to record for replay,
     * if any. */
@@ -688,11 +692,11 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
     else if (intp.global == null) Result(false, None)  // Notice failure to create compiler
     else Result(true, interpretStartingWith(line))
   }
-  
+
   private def readWhile(cond: String => Boolean) = {
     Iterator continually in.readLine("") takeWhile (x => x != null && cond(x))
   }
-  
+
   def pasteCommand(): Result = {
     echo("// Entering paste mode (ctrl-D to finish)\n")
     val code = readWhile(_ => true) mkString "\n"
@@ -700,17 +704,17 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
     intp interpret code
     ()
   }
-    
+
   private object paste extends Pasted {
     val ContinueString = "     | "
     val PromptString   = "scala> "
-    
+
     def interpret(line: String): Unit = {
       echo(line.trim)
       intp interpret line
       echo("")
     }
-    
+
     def transcript(start: String) = {
       // Printing this message doesn't work very well because it's buried in the
       // transcript they just pasted.  Todo: a short timer goes off when
@@ -731,7 +735,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
   def interpretStartingWith(code: String): Option[String] = {
     // signal completion non-completion input has been received
     in.completion.resetVerbosity()
-    
+
     def reallyInterpret = {
       val reallyResult = intp.interpret(code)
       (reallyResult, reallyResult match {
@@ -741,7 +745,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
           if (in.interactive && code.endsWith("\n\n")) {
             echo("You typed two blank lines.  Starting a new command.")
             None
-          } 
+          }
           else in.readLine(ContinueString) match {
             case null =>
               // we know compilation is going to fail since we're at EOF and the
@@ -755,10 +759,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
           }
       })
     }
-    
+
     /** Here we place ourselves between the user and the interpreter and examine
      *  the input they are ostensibly submitting.  We intervene in several cases:
-     * 
+     *
      *  1) If the line starts with "scala> " it is assumed to be an interpreter paste.
      *  2) If the line starts with "." (but not ".." or "./") it is treated as an invocation
      *     on the previous result.
@@ -787,7 +791,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
         val (code, result) = reallyInterpret
         //if (power != null && code == IR.Error)
         //  runCompletion
-        
+
         result
       }
       else runCompletion match {
@@ -808,7 +812,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
       }
     case _ =>
   }
-  
+
   /** Tries to create a JLineReader, falling back to SimpleReader:
    *  unless settings or properties are such that it should start
    *  with SimpleReader.
@@ -837,6 +841,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
         org.apache.spark.repl.Main.interp.out.flush();
         """)
       command("import org.apache.spark.SparkContext._")
+      val prop = System.getenv("SPARK_SHELL_INIT_BLOCK")
+      if (prop != null) {
+        command(prop)
+      }
     }
     echo("Type in expressions to have them evaluated.")
     echo("Type :help for more information.")
@@ -884,7 +892,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
 
     this.settings = settings
     createInterpreter()
-    
+
     // sets in to some kind of reader depending on environmental cues
     in = in0 match {
       case Some(reader) => SimpleReader(reader, out, true)
@@ -895,10 +903,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
     // it is broken on startup; go ahead and exit
     if (intp.reporter.hasErrors)
       return false
-    
-    try {      
+
+    try {
       // this is about the illusion of snappiness.  We call initialize()
-      // which spins off a separate thread, then print the prompt and try 
+      // which spins off a separate thread, then print the prompt and try
       // our best to look ready.  Ideally the user will spend a
       // couple seconds saying "wow, it starts so fast!" and by the time
       // they type a command the compiler is ready to roll.
@@ -920,19 +928,19 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
     def neededHelp(): String =
       (if (command.settings.help.value) command.usageMsg + "\n" else "") +
       (if (command.settings.Xhelp.value) command.xusageMsg + "\n" else "")
-    
+
     // if they asked for no help and command is valid, we call the real main
     neededHelp() match {
       case ""     => command.ok && process(command.settings)
       case help   => echoNoNL(help) ; true
     }
   }
-  
+
   @deprecated("Use `process` instead", "2.9.0")
   def main(args: Array[String]): Unit = {
     if (isReplDebug)
       System.out.println(new java.util.Date)
-    
+
     process(args)
   }
   @deprecated("Use `process` instead", "2.9.0")
@@ -948,7 +956,7 @@ object SparkILoop {
   // like if you'd just typed it into the repl.
   def runForTranscript(code: String, settings: Settings): String = {
     import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
-    
+
     stringFromStream { ostream =>
       Console.withOut(ostream) {
         val output = new PrintWriter(new OutputStreamWriter(ostream), true) {
@@ -977,19 +985,19 @@ object SparkILoop {
       }
     }
   }
-  
+
   /** Creates an interpreter loop with default settings and feeds
    *  the given code to it as input.
    */
   def run(code: String, sets: Settings = new Settings): String = {
     import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
-    
+
     stringFromStream { ostream =>
       Console.withOut(ostream) {
         val input    = new BufferedReader(new StringReader(code))
         val output   = new PrintWriter(new OutputStreamWriter(ostream), true)
         val repl     = new SparkILoop(input, output)
-        
+
         if (sets.classpath.isDefault)
           sets.classpath.value = sys.props("java.class.path")
 
@@ -1017,7 +1025,7 @@ object SparkILoop {
     repl.settings.embeddedDefaults[T]
     repl.createInterpreter()
     repl.in = SparkJLineReader(repl)
-    
+
     // rebind exit so people don't accidentally call sys.exit by way of predef
     repl.quietRun("""def exit = println("Type :quit to resume program execution.")""")
     args foreach (p => repl.bind(p.name, p.tpe, p.value))
@@ -1025,5 +1033,5 @@ object SparkILoop {
 
     echo("\nDebug repl exiting.")
     repl.closeInterpreter()
-  }  
+  }
 }