Commit f6c94620 authored by Reynold Xin

Merge pull request #58 from jegonzal/KryoMessages

Kryo messages
parents bac7be30 6083e435
Showing 361 additions and 108 deletions
@@ -157,6 +157,16 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
/** Return the value at the specified position. */
def getValue(pos: Int): T = _data(pos)
def iterator() = new Iterator[T] {
var pos = nextPos(0)
override def hasNext: Boolean = pos != INVALID_POS
override def next(): T = {
val tmp = getValue(pos)
pos = nextPos(pos+1)
tmp
}
}
/** Return the value at the specified position. */
def getValueSafe(pos: Int): T = {
assert(_bitset.get(pos))
......
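The new iterator() walks the slot array via nextPos, so a set can now be traversed without touching positions directly. A minimal sketch of using it, assuming the usual OpenHashSet constructor that takes an initial capacity:

import org.apache.spark.util.collection.OpenHashSet

val set = new OpenHashSet[Long](64)  // 64 is an arbitrary initial capacity
set.add(1L)
set.add(5L)
set.add(9L)
// hasNext turns false once nextPos reports INVALID_POS; each element comes out once.
val elems = set.iterator.toSet       // Set(1L, 5L, 9L), in no particular order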
package org.apache.spark.graph
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
/**
* The Graph abstractly represents a graph with arbitrary objects
@@ -12,21 +12,21 @@ import org.apache.spark.rdd.RDD
* operations return new graphs.
*
* @see GraphOps for additional graph member functions.
*
* @note The majority of the graph operations are implemented in
* `GraphOps`, a class of convenience operations that may be shared
* across multiple graph implementations.
*
* @tparam VD the vertex attribute type
* @tparam ED the edge attribute type
*/
abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
/**
* Get the vertices and their data.
*
* @note vertex ids are unique.
* @return An RDD containing the vertices in this graph
*
* @see Vertex for the vertex type.
@@ -70,6 +70,11 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
*/
val triplets: RDD[EdgeTriplet[VD, ED]]
def persist(newLevel: StorageLevel): Graph[VD, ED]
/**
* Return a graph that is cached when first created. This is used to
* pin a graph in memory enabling multiple queries to reuse the same
@@ -100,7 +105,7 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
* @tparam VD2 the new vertex data type
*
* @example We might use this operation to change the vertex values
* from one type to another to initialize an algorithm.
* {{{
* val rawGraph: Graph[(), ()] = Graph.textFile("hdfs://file")
* val root = 42
@@ -190,7 +195,7 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
* @return the subgraph containing only the vertices and edges that
* satisfy the predicates.
*/
def subgraph(epred: EdgeTriplet[VD,ED] => Boolean = (x => true),
vpred: (Vid, VD) => Boolean = ((v,d) => true) ): Graph[VD, ED]
@@ -255,12 +260,12 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
* @param reduceFunc the user defined reduce function which should
* be commutative and associative and is used to combine the output
* of the map phase.
*
* @example We can use this function to compute the inDegree of each
* vertex
* {{{
* val rawGraph: Graph[(),()] = Graph.textFile("twittergraph")
* val inDeg: RDD[(Vid, Int)] =
* mapReduceTriplets[Int](et => Array((et.dst.id, 1)), _ + _)
* }}}
*
@@ -269,12 +274,12 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
* Graph API in that it enables neighborhood-level computation. For
* example, this function can be used to count neighbors satisfying a
* predicate or implement PageRank.
*
*/
def mapReduceTriplets[A: ClassManifest](
mapFunc: EdgeTriplet[VD, ED] => Array[(Vid, A)],
reduceFunc: (A, A) => A)
: VertexSetRDD[A]
/**
@@ -296,11 +301,11 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
* @example This function is used to update the vertices with new
* values based on external data. For example, we could add the out
* degree to each vertex record
*
* {{{
* val rawGraph: Graph[(),()] = Graph.textFile("webgraph")
* val outDeg: RDD[(Vid, Int)] = rawGraph.outDegrees()
* val graph = rawGraph.outerJoinVertices(outDeg) {
* (vid, data, optDeg) => optDeg.getOrElse(0)
* }
* }}}
@@ -337,7 +342,7 @@ object Graph {
* (i.e., the undirected degree).
*
* @param rawEdges the RDD containing the set of edges in the graph
*
* @return a graph with edge attributes containing the count of
* duplicate edges and vertex attributes containing the total degree
* of each vertex.
@@ -368,10 +373,10 @@ object Graph {
rawEdges.map { case (s, t) => Edge(s, t, 1) }
}
// Determine unique vertices
/** @todo Should this reduceByKey operation be indexed? */
val vertices: RDD[(Vid, Int)] =
edges.flatMap{ case Edge(s, t, cnt) => Array((s, 1), (t, 1)) }.reduceByKey(_ + _)
// Return graph
GraphImpl(vertices, edges, 0)
}
@@ -392,7 +397,7 @@ object Graph {
*
*/
def apply[VD: ClassManifest, ED: ClassManifest](
vertices: RDD[(Vid,VD)],
edges: RDD[Edge[ED]]): Graph[VD, ED] = {
val defaultAttr: VD = null.asInstanceOf[VD]
Graph(vertices, edges, defaultAttr, (a:VD,b:VD) => a)
@@ -416,7 +421,7 @@ object Graph {
*
*/
def apply[VD: ClassManifest, ED: ClassManifest](
vertices: RDD[(Vid,VD)],
edges: RDD[Edge[ED]],
defaultVertexAttr: VD,
mergeFunc: (VD, VD) => VD): Graph[VD, ED] = {
......
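A hedged sketch of calling the apply overload above; the SparkContext sc and the attribute values are assumed for illustration:

import org.apache.spark.graph._
import org.apache.spark.rdd.RDD

val vertices: RDD[(Vid, String)] = sc.parallelize(Seq((1L, "a"), (2L, "b")))
val edges: RDD[Edge[Int]] = sc.parallelize(Seq(Edge(1L, 2L, 7), Edge(2L, 3L, 8)))
// Vertex 3 appears only in the edge list, so it gets the default attribute;
// duplicate vertex entries would be combined with the merge function.
val graph: Graph[String, Int] =
  Graph(vertices, edges, "default", (a: String, b: String) => a)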
@@ -2,7 +2,7 @@ package org.apache.spark.graph
import com.esotericsoftware.kryo.Kryo
-import org.apache.spark.graph.impl.{EdgePartition, MessageToPartition}
+import org.apache.spark.graph.impl._
import org.apache.spark.serializer.KryoRegistrator
import org.apache.spark.util.collection.BitSet
@@ -12,6 +12,8 @@ class GraphKryoRegistrator extends KryoRegistrator {
kryo.register(classOf[Edge[Object]])
kryo.register(classOf[MutableTuple2[Object, Object]])
kryo.register(classOf[MessageToPartition[Object]])
kryo.register(classOf[VertexBroadcastMsg[Object]])
kryo.register(classOf[AggregationMsg[Object]])
kryo.register(classOf[(Vid, Object)])
kryo.register(classOf[EdgePartition[Object]])
kryo.register(classOf[BitSet])
......
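For these registrations to take effect, an application points Spark at the Kryo serializer and this registrator before creating the context, exactly as the test suite at the bottom of this commit does:

System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
System.setProperty("spark.kryo.registrator", "org.apache.spark.graph.GraphKryoRegistrator")
val sc = new org.apache.spark.SparkContext("local[2]", "kryo-example")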
@@ -98,14 +98,14 @@ object Pregel {
: Graph[VD, ED] = {
// Receive the first set of messages
-    var g = graph.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg))
+    var g = graph.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg)).cache
var i = 0
while (i < numIter) {
// compute the messages
val messages = g.mapReduceTriplets(sendMsg, mergeMsg)
// receive the messages
-      g = g.joinVertices(messages)(vprog)
+      g = g.joinVertices(messages)(vprog).cache
// count the iteration
i += 1
}
......
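The added .cache calls matter because g is rebuilt every iteration: without them, each mapReduceTriplets would recompute the entire lineage back to iteration 0. A sketch of driving the loop, assuming it is exposed as Pregel(graph, initialMsg, numIter)(vprog, sendMsg, mergeMsg) (the exact entry point is outside this hunk) and that graph: Graph[Long, Int] exists; the max-propagation program is illustrative:

val result: Graph[Long, Int] = Pregel(graph, Long.MinValue, 10)(
  (vid: Vid, attr: Long, msg: Long) => math.max(attr, msg),  // vprog: absorb message
  et => Array((et.dst.id, et.src.data)),                     // sendMsg: forward src attr
  (a: Long, b: Long) => math.max(a, b))                      // mergeMsg: keep the max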
@@ -13,6 +13,7 @@ import org.apache.spark.graph._
import org.apache.spark.graph.impl.GraphImpl._
import org.apache.spark.graph.impl.MsgRDDFunctions._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveKeyOpenHashMap}
@@ -95,13 +96,17 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
@transient override val triplets: RDD[EdgeTriplet[VD, ED]] =
makeTriplets(localVidMap, vTableReplicatedValues, eTable)
-  override def cache(): Graph[VD, ED] = {
-    eTable.cache()
-    vid2pid.cache()
-    vTable.cache()
+  override def persist(newLevel: StorageLevel): Graph[VD, ED] = {
+    eTable.persist(newLevel)
+    vid2pid.persist(newLevel)
+    vTable.persist(newLevel)
+    localVidMap.persist(newLevel)
+    // vTableReplicatedValues.persist(newLevel)
     this
   }
+  override def cache(): Graph[VD, ED] = persist(StorageLevel.MEMORY_ONLY)
override def statistics: Map[String, Any] = {
val numVertices = this.numVertices
val numEdges = this.numEdges
@@ -371,7 +376,7 @@ object GraphImpl {
val vSet = new VertexSet
edgePartition.foreach(e => {vSet.add(e.srcId); vSet.add(e.dstId)})
vSet.iterator.map { vid => (vid.toLong, pid) }
-    }
+    }.partitionBy(vTableIndex.rdd.partitioner.get)
VertexSetRDD[Pid, ArrayBuffer[Pid]](preAgg, vTableIndex,
(p: Pid) => ArrayBuffer(p),
(ab: ArrayBuffer[Pid], p:Pid) => {ab.append(p); ab},
@@ -508,7 +513,7 @@ object GraphImpl {
}
}.partitionBy(g.vTable.index.rdd.partitioner.get)
// do the final reduction reusing the index map
-    VertexSetRDD(preAgg, g.vTable.index, reduceFunc)
+    VertexSetRDD.aggregate(preAgg, g.vTable.index, reduceFunc)
}
protected def edgePartitionFunction1D(src: Vid, dst: Vid, numParts: Pid): Pid = {
......
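With persist(newLevel) in place on GraphImpl, cache() is just the MEMORY_ONLY special case, and callers can pick other storage levels. A small usage sketch, assuming graph was built earlier (e.g. with Graph(vertices, edges)):

import org.apache.spark.storage.StorageLevel

val pinned = graph.persist(StorageLevel.MEMORY_AND_DISK)
pinned.triplets.count()  // the first action materializes eTable, vid2pid, and vTable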
@@ -55,6 +55,8 @@ class VertexBroadcastMsgRDDFunctions[T: ClassManifest](self: RDD[VertexBroadcast
// Set a custom serializer if the data is of int or double type.
if (classManifest[T] == ClassManifest.Int) {
rdd.setSerializer(classOf[IntVertexBroadcastMsgSerializer].getName)
} else if (classManifest[T] == ClassManifest.Long) {
rdd.setSerializer(classOf[LongVertexBroadcastMsgSerializer].getName)
} else if (classManifest[T] == ClassManifest.Double) {
rdd.setSerializer(classOf[DoubleVertexBroadcastMsgSerializer].getName)
}
@@ -70,6 +72,8 @@ class AggregationMessageRDDFunctions[T: ClassManifest](self: RDD[AggregationMsg[
// Set a custom serializer if the data is of int or double type.
if (classManifest[T] == ClassManifest.Int) {
rdd.setSerializer(classOf[IntAggMsgSerializer].getName)
} else if (classManifest[T] == ClassManifest.Long) {
rdd.setSerializer(classOf[LongAggMsgSerializer].getName)
} else if (classManifest[T] == ClassManifest.Double) {
rdd.setSerializer(classOf[DoubleAggMsgSerializer].getName)
}
......
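The serializer choice above hinges on runtime ClassManifest equality. A self-contained sketch of the same dispatch pattern (the function and returned names are illustrative, not part of the patch):

def serializerFor[T: ClassManifest]: Option[String] =
  if (classManifest[T] == ClassManifest.Int) Some("IntAggMsgSerializer")
  else if (classManifest[T] == ClassManifest.Long) Some("LongAggMsgSerializer")
  else if (classManifest[T] == ClassManifest.Double) Some("DoubleAggMsgSerializer")
  else None  // anything else falls back to the default shuffle serializer

assert(serializerFor[Long] == Some("LongAggMsgSerializer"))
assert(serializerFor[String] == None)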
@@ -27,6 +27,28 @@ class IntVertexBroadcastMsgSerializer extends Serializer {
}
}
/** A special shuffle serializer for VertexBroadcastMessage[Long]. */
class LongVertexBroadcastMsgSerializer extends Serializer {
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
def writeObject[T](t: T) = {
val msg = t.asInstanceOf[VertexBroadcastMsg[Long]]
writeLong(msg.vid)
writeLong(msg.data)
this
}
}
override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
override def readObject[T](): T = {
val a = readLong()
val b = readLong()
new VertexBroadcastMsg[Long](0, a, b).asInstanceOf[T]
}
}
}
}
/** A special shuffle serializer for VertexBroadcastMessage[Double]. */
class DoubleVertexBroadcastMsgSerializer extends Serializer {
@@ -43,7 +65,9 @@ class DoubleVertexBroadcastMsgSerializer {
override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
def readObject[T](): T = {
-      new VertexBroadcastMsg[Double](0, readLong(), readDouble()).asInstanceOf[T]
+      val a = readLong()
+      val b = readDouble()
+      new VertexBroadcastMsg[Double](0, a, b).asInstanceOf[T]
}
}
}
@@ -65,7 +89,32 @@ class IntAggMsgSerializer extends Serializer {
override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
override def readObject[T](): T = {
-      new AggregationMsg[Int](readLong(), readInt()).asInstanceOf[T]
+      val a = readLong()
+      val b = readInt()
+      new AggregationMsg[Int](a, b).asInstanceOf[T]
}
}
}
}
/** A special shuffle serializer for AggregationMessage[Long]. */
class LongAggMsgSerializer extends Serializer {
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
def writeObject[T](t: T) = {
val msg = t.asInstanceOf[AggregationMsg[Long]]
writeLong(msg.vid)
writeLong(msg.data)
this
}
}
override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
override def readObject[T](): T = {
val a = readLong()
val b = readLong()
new AggregationMsg[Long](a, b).asInstanceOf[T]
}
}
}
@@ -87,7 +136,9 @@ class DoubleAggMsgSerializer extends Serializer {
override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
def readObject[T](): T = {
-      new AggregationMsg[Double](readLong(), readDouble()).asInstanceOf[T]
+      val a = readLong()
+      val b = readDouble()
+      new AggregationMsg[Double](a, b).asInstanceOf[T]
}
}
}
......
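Each specialized stream writes fixed-width records with no per-record type tags: a VertexBroadcastMsg[Long] is serialized as two longs, vid then data, and the partition id is reconstructed as 0 on read, presumably because it is no longer needed once the message has arrived. A hedged size check, assuming writeLong emits 8 bytes:

val bout = new java.io.ByteArrayOutputStream
val out = new LongVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
out.writeObject(new VertexBroadcastMsg[Long](0, 3L, 4L))
bout.flush()
assert(bout.toByteArray.length == 16)  // 8 bytes for vid + 8 for data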
@@ -8,10 +8,9 @@ package object graph {
type Vid = Long
type Pid = Int
type VertexHashMap[T] = it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap[T]
-  type VertexSet = it.unimi.dsi.fastutil.longs.LongOpenHashSet
+  type VertexSet = OpenHashSet[Vid]
type VertexArrayList = it.unimi.dsi.fastutil.longs.LongArrayList
-  // type VertexIdToIndexMap = it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap
type VertexIdToIndexMap = OpenHashSet[Vid]
......
package org.apache.spark.graph
import org.scalatest.FunSuite
import org.apache.spark.SparkContext
import org.apache.spark.graph.LocalSparkContext._
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import org.apache.spark.graph.impl._
import org.apache.spark.graph.impl.MsgRDDFunctions._
import org.apache.spark._
class SerializerSuite extends FunSuite with LocalSparkContext {
System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
System.setProperty("spark.kryo.registrator", "org.apache.spark.graph.GraphKryoRegistrator")
test("TestVertexBroadcastMessageInt") {
val outMsg = new VertexBroadcastMsg[Int](3,4,5)
val bout = new ByteArrayOutputStream
val outStrm = new IntVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
outStrm.writeObject(outMsg)
outStrm.writeObject(outMsg)
bout.flush
val bin = new ByteArrayInputStream(bout.toByteArray)
val inStrm = new IntVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
val inMsg1: VertexBroadcastMsg[Int] = inStrm.readObject()
val inMsg2: VertexBroadcastMsg[Int] = inStrm.readObject()
assert(outMsg.vid === inMsg1.vid)
assert(outMsg.vid === inMsg2.vid)
assert(outMsg.data === inMsg1.data)
assert(outMsg.data === inMsg2.data)
}
test("TestVertexBroadcastMessageLong") {
val outMsg = new VertexBroadcastMsg[Long](3,4,5)
val bout = new ByteArrayOutputStream
val outStrm = new LongVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
outStrm.writeObject(outMsg)
outStrm.writeObject(outMsg)
bout.flush
val bin = new ByteArrayInputStream(bout.toByteArray)
val inStrm = new LongVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
val inMsg1: VertexBroadcastMsg[Long] = inStrm.readObject()
val inMsg2: VertexBroadcastMsg[Long] = inStrm.readObject()
assert(outMsg.vid === inMsg1.vid)
assert(outMsg.vid === inMsg2.vid)
assert(outMsg.data === inMsg1.data)
assert(outMsg.data === inMsg2.data)
}
test("TestVertexBroadcastMessageDouble") {
val outMsg = new VertexBroadcastMsg[Double](3,4,5.0)
val bout = new ByteArrayOutputStream
val outStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
outStrm.writeObject(outMsg)
outStrm.writeObject(outMsg)
bout.flush
val bin = new ByteArrayInputStream(bout.toByteArray)
val inStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
val inMsg1: VertexBroadcastMsg[Double] = inStrm.readObject()
val inMsg2: VertexBroadcastMsg[Double] = inStrm.readObject()
assert(outMsg.vid === inMsg1.vid)
assert(outMsg.vid === inMsg2.vid)
assert(outMsg.data === inMsg1.data)
assert(outMsg.data === inMsg2.data)
}
test("TestAggregationMessageInt") {
val outMsg = new AggregationMsg[Int](4,5)
val bout = new ByteArrayOutputStream
val outStrm = new IntAggMsgSerializer().newInstance().serializeStream(bout)
outStrm.writeObject(outMsg)
outStrm.writeObject(outMsg)
bout.flush
val bin = new ByteArrayInputStream(bout.toByteArray)
val inStrm = new IntAggMsgSerializer().newInstance().deserializeStream(bin)
val inMsg1: AggregationMsg[Int] = inStrm.readObject()
val inMsg2: AggregationMsg[Int] = inStrm.readObject()
assert(outMsg.vid === inMsg1.vid)
assert(outMsg.vid === inMsg2.vid)
assert(outMsg.data === inMsg1.data)
assert(outMsg.data === inMsg2.data)
}
test("TestAggregationMessageLong") {
val outMsg = new AggregationMsg[Long](4,5)
val bout = new ByteArrayOutputStream
val outStrm = new LongAggMsgSerializer().newInstance().serializeStream(bout)
outStrm.writeObject(outMsg)
outStrm.writeObject(outMsg)
bout.flush
val bin = new ByteArrayInputStream(bout.toByteArray)
val inStrm = new LongAggMsgSerializer().newInstance().deserializeStream(bin)
val inMsg1: AggregationMsg[Long] = inStrm.readObject()
val inMsg2: AggregationMsg[Long] = inStrm.readObject()
assert(outMsg.vid === inMsg1.vid)
assert(outMsg.vid === inMsg2.vid)
assert(outMsg.data === inMsg1.data)
assert(outMsg.data === inMsg2.data)
}
test("TestAggregationMessageDouble") {
val outMsg = new AggregationMsg[Double](4,5.0)
val bout = new ByteArrayOutputStream
val outStrm = new DoubleAggMsgSerializer().newInstance().serializeStream(bout)
outStrm.writeObject(outMsg)
outStrm.writeObject(outMsg)
bout.flush
val bin = new ByteArrayInputStream(bout.toByteArray)
val inStrm = new DoubleAggMsgSerializer().newInstance().deserializeStream(bin)
val inMsg1: AggregationMsg[Double] = inStrm.readObject()
val inMsg2: AggregationMsg[Double] = inStrm.readObject()
assert(outMsg.vid === inMsg1.vid)
assert(outMsg.vid === inMsg2.vid)
assert(outMsg.data === inMsg1.data)
assert(outMsg.data === inMsg2.data)
}
test("TestShuffleVertexBroadcastMsg") {
withSpark(new SparkContext("local[2]", "test")) { sc =>
val bmsgs = sc.parallelize(
(0 until 100).map(pid => new VertexBroadcastMsg[Int](pid, pid, pid)), 10)
val partitioner = new HashPartitioner(3)
val bmsgsArray = bmsgs.partitionBy(partitioner).collect
}
}
test("TestShuffleAggregationMsg") {
withSpark(new SparkContext("local[2]", "test")) { sc =>
val bmsgs = sc.parallelize(
(0 until 100).map(pid => new AggregationMsg[Int](pid, pid)), 10)
val partitioner = new HashPartitioner(3)
val bmsgsArray = bmsgs.partitionBy(partitioner).collect
}
}
}
\ No newline at end of file