diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala b/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
index f65f96ed0c1b2dd62f1d5bf0b20b77cb060267af..82b9198e432c728179de848b2b8b1672e7dbbc4d 100644
--- a/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
@@ -2,7 +2,7 @@ package org.apache.spark.graph
 
 import com.esotericsoftware.kryo.Kryo
 
-import org.apache.spark.graph.impl.{EdgePartition, MessageToPartition}
+import org.apache.spark.graph.impl._
 import org.apache.spark.serializer.KryoRegistrator
 import org.apache.spark.util.collection.BitSet
 
@@ -12,6 +12,8 @@ class GraphKryoRegistrator extends KryoRegistrator {
     kryo.register(classOf[Edge[Object]])
     kryo.register(classOf[MutableTuple2[Object, Object]])
     kryo.register(classOf[MessageToPartition[Object]])
+    kryo.register(classOf[VertexBroadcastMsg[Object]])
+    kryo.register(classOf[AggregationMsg[Object]])
     kryo.register(classOf[(Vid, Object)])
     kryo.register(classOf[EdgePartition[Object]])
     kryo.register(classOf[BitSet])
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala b/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala
index 3fc0b7c0f7588d3b2aff84868367edc7f45dcd39..d0a5adb85cd8a3d6bb28ff588e527465a0b3c3dc 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala
@@ -55,6 +55,8 @@ class VertexBroadcastMsgRDDFunctions[T: ClassManifest](self: RDD[VertexBroadcast
-    // Set a custom serializer if the data is of int or double type.
+    // Set a custom serializer if the data is of int, long, or double type.
     if (classManifest[T] == ClassManifest.Int) {
       rdd.setSerializer(classOf[IntVertexBroadcastMsgSerializer].getName)
+    } else if (classManifest[T] == ClassManifest.Long) {
+      rdd.setSerializer(classOf[LongVertexBroadcastMsgSerializer].getName)
     } else if (classManifest[T] == ClassManifest.Double) {
       rdd.setSerializer(classOf[DoubleVertexBroadcastMsgSerializer].getName)
     }
@@ -70,6 +72,8 @@ class AggregationMessageRDDFunctions[T: ClassManifest](self: RDD[AggregationMsg[
-    // Set a custom serializer if the data is of int or double type.
+    // Set a custom serializer if the data is of int, long, or double type.
     if (classManifest[T] == ClassManifest.Int) {
       rdd.setSerializer(classOf[IntAggMsgSerializer].getName)
+    } else if (classManifest[T] == ClassManifest.Long) {
+      rdd.setSerializer(classOf[LongAggMsgSerializer].getName)
     } else if (classManifest[T] == ClassManifest.Double) {
       rdd.setSerializer(classOf[DoubleAggMsgSerializer].getName)
     }
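// Both branches added above follow the same dispatch pattern: the shuffle
// serializer is chosen by comparing the payload's ClassManifest against the
// primitive manifests, so only exact Int/Long/Double payloads take a
// specialized serializer and every other type falls back to the SparkEnv
// default. A minimal standalone sketch of that dispatch (hypothetical object
// and method names, not part of this patch):
object ManifestDispatchSketch {
  def serializerNameFor[T: ClassManifest]: Option[String] = {
    if (classManifest[T] == ClassManifest.Int) Some("IntAggMsgSerializer")
    else if (classManifest[T] == ClassManifest.Long) Some("LongAggMsgSerializer")
    else if (classManifest[T] == ClassManifest.Double) Some("DoubleAggMsgSerializer")
    else None  // no specialization: keep the default shuffle serializer
  }

  def main(args: Array[String]) {
    println(serializerNameFor[Long])    // Some(LongAggMsgSerializer)
    println(serializerNameFor[String])  // None
  }
}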
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala b/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala
index 8b4c0868b1a6dff011cab21938f60fc1632c6ece..54fd65e7381f2512a8c069bc91abf95f7d58c320 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala
@@ -27,6 +27,28 @@
   }
 }
 
+/** A special shuffle serializer for VertexBroadcastMessage[Long]. */
+class LongVertexBroadcastMsgSerializer extends Serializer {
+  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+      def writeObject[T](t: T) = {
+        val msg = t.asInstanceOf[VertexBroadcastMsg[Long]]
+        writeLong(msg.vid)
+        writeLong(msg.data)
+        this
+      }
+    }
+
+    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+      override def readObject[T](): T = {
+        val a = readLong()
+        val b = readLong()
+        new VertexBroadcastMsg[Long](0, a, b).asInstanceOf[T]
+      }
+    }
+  }
+}
 
 /** A special shuffle serializer for VertexBroadcastMessage[Double]. */
 class DoubleVertexBroadcastMsgSerializer extends Serializer {
@@ -43,7 +65,9 @@
 
     override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
       def readObject[T](): T = {
-        new VertexBroadcastMsg[Double](0, readLong(), readDouble()).asInstanceOf[T]
+        val a = readLong()
+        val b = readDouble()
+        new VertexBroadcastMsg[Double](0, a, b).asInstanceOf[T]
       }
     }
   }
@@ -65,7 +89,32 @@
 
     override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
       override def readObject[T](): T = {
-        new AggregationMsg[Int](readLong(), readInt()).asInstanceOf[T]
+        val a = readLong()
+        val b = readInt()
+        new AggregationMsg[Int](a, b).asInstanceOf[T]
+      }
+    }
+  }
+}
+
+/** A special shuffle serializer for AggregationMessage[Long]. */
+class LongAggMsgSerializer extends Serializer {
+  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+      def writeObject[T](t: T) = {
+        val msg = t.asInstanceOf[AggregationMsg[Long]]
+        writeLong(msg.vid)
+        writeLong(msg.data)
+        this
+      }
+    }
+
+    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+      override def readObject[T](): T = {
+        val a = readLong()
+        val b = readLong()
+        new AggregationMsg[Long](a, b).asInstanceOf[T]
       }
     }
   }
@@ -87,7 +136,9 @@
 
     override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
       def readObject[T](): T = {
-        new AggregationMsg[Double](readLong(), readDouble()).asInstanceOf[T]
+        val a = readLong()
+        val b = readDouble()
+        new AggregationMsg[Double](a, b).asInstanceOf[T]
       }
     }
   }
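// The Long serializers above write a fixed 16-byte record: the 8-byte vid
// followed by the 8-byte payload, with no type tags or framing. The partition
// id of a VertexBroadcastMsg is not written at all; presumably because the
// message has already been routed by the time it is deserialized, the field is
// rebuilt as a dummy 0. The explicit `val a` / `val b` temporaries in the
// readObject methods pin the two reads to the vid-then-data wire order rather
// than relying on constructor-argument evaluation order. A rough standalone
// illustration of the wire format, assuming writeLong/readLong behave like the
// corresponding java.io.DataOutputStream/DataInputStream methods (a sketch,
// not part of this patch):
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object LongMsgWireFormatSketch {
  def main(args: Array[String]) {
    val bytes = new ByteArrayOutputStream
    val out = new DataOutputStream(bytes)
    out.writeLong(3L)   // vid
    out.writeLong(5L)   // data
    out.flush()
    println(bytes.size) // 16 -- two raw longs per message, no headers

    val in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray))
    val vid = in.readLong()  // reads must mirror the write order
    val data = in.readLong()
    println((vid, data))     // (3,5)
  }
}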
diff --git a/graph/src/test/scala/org/apache/spark/graph/SerializerSuite.scala b/graph/src/test/scala/org/apache/spark/graph/SerializerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..5a59fd912a519a7e45ba504e4c4f3ae1608fee2d
--- /dev/null
+++ b/graph/src/test/scala/org/apache/spark/graph/SerializerSuite.scala
@@ -0,0 +1,139 @@
+package org.apache.spark.graph
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.graph.LocalSparkContext._
+import java.io.ByteArrayInputStream
+import java.io.ByteArrayOutputStream
+import org.apache.spark.graph.impl._
+import org.apache.spark.graph.impl.MsgRDDFunctions._
+import org.apache.spark._
+
+
+class SerializerSuite extends FunSuite with LocalSparkContext {
+
+  System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+  System.setProperty("spark.kryo.registrator", "org.apache.spark.graph.GraphKryoRegistrator")
+
+  test("TestVertexBroadcastMessageInt") {
+    val outMsg = new VertexBroadcastMsg[Int](3,4,5)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new IntVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new IntVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: VertexBroadcastMsg[Int] = inStrm.readObject()
+    val inMsg2: VertexBroadcastMsg[Int] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestVertexBroadcastMessageLong") {
+    val outMsg = new VertexBroadcastMsg[Long](3,4,5)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new LongVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new LongVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: VertexBroadcastMsg[Long] = inStrm.readObject()
+    val inMsg2: VertexBroadcastMsg[Long] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestVertexBroadcastMessageDouble") {
+    val outMsg = new VertexBroadcastMsg[Double](3,4,5.0)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: VertexBroadcastMsg[Double] = inStrm.readObject()
+    val inMsg2: VertexBroadcastMsg[Double] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestAggregationMessageInt") {
+    val outMsg = new AggregationMsg[Int](4,5)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new IntAggMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new IntAggMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: AggregationMsg[Int] = inStrm.readObject()
+    val inMsg2: AggregationMsg[Int] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestAggregationMessageLong") {
+    val outMsg = new AggregationMsg[Long](4,5)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new LongAggMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new LongAggMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: AggregationMsg[Long] = inStrm.readObject()
+    val inMsg2: AggregationMsg[Long] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestAggregationMessageDouble") {
+    val outMsg = new AggregationMsg[Double](4,5.0)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new DoubleAggMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new DoubleAggMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: AggregationMsg[Double] = inStrm.readObject()
+    val inMsg2: AggregationMsg[Double] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestShuffleVertexBroadcastMsg") {
+    withSpark(new SparkContext("local[2]", "test")) { sc =>
+      val bmsgs = sc.parallelize(
+        (0 until 100).map(pid => new VertexBroadcastMsg[Int](pid, pid, pid)), 10)
+      val partitioner = new HashPartitioner(3)
+      val bmsgsArray = bmsgs.partitionBy(partitioner).collect
+    }
+  }
+
+  test("TestShuffleAggregationMsg") {
+    withSpark(new SparkContext("local[2]", "test")) { sc =>
+      val bmsgs = sc.parallelize(
+        (0 until 100).map(pid => new AggregationMsg[Int](pid, pid)), 10)
+      val partitioner = new HashPartitioner(3)
+      val bmsgsArray = bmsgs.partitionBy(partitioner).collect
+    }
+  }
+
+}
\ No newline at end of file
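// The two shuffle tests above push messages through the specialized
// serializers end-to-end but discard the collected result. A stronger variant
// (a sketch only -- the suite and test names here are hypothetical, not part
// of this patch) would also assert that no message is lost or corrupted by
// the shuffle:
package org.apache.spark.graph

import org.scalatest.FunSuite

import org.apache.spark.SparkContext
import org.apache.spark.graph.LocalSparkContext._
import org.apache.spark.graph.impl._
import org.apache.spark.graph.impl.MsgRDDFunctions._
import org.apache.spark._

class SerializerShuffleContentsSuite extends FunSuite with LocalSparkContext {

  System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  System.setProperty("spark.kryo.registrator", "org.apache.spark.graph.GraphKryoRegistrator")

  test("TestShuffleAggregationMsgContents") {
    withSpark(new SparkContext("local[2]", "test")) { sc =>
      val bmsgs = sc.parallelize(
        (0 until 100).map(pid => new AggregationMsg[Int](pid, pid)), 10)
      val shuffled = bmsgs.partitionBy(new HashPartitioner(3)).collect
      // Every payload should come back exactly once, in some order.
      assert(shuffled.length === 100)
      assert(shuffled.map(_.data).sorted.toSeq === (0 until 100))
    }
  }
}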