Commit 0e813cd4 authored by Reynold Xin

Fix the hanging bug.

parent f6c94620
 package org.apache.spark.graph.impl

-import java.io.{InputStream, OutputStream}
+import java.io.{EOFException, InputStream, OutputStream}
 import java.nio.ByteBuffer

-import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance, Serializer}
+import org.apache.spark.serializer._

 /** A special shuffle serializer for VertexBroadcastMessage[Int]. */
@@ -185,11 +185,15 @@ sealed abstract class ShuffleDeserializationStream(s: InputStream) extends Deser
   def readObject[T](): T

   def readInt(): Int = {
-    (s.read() & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF)
+    val first = s.read()
+    if (first < 0) throw new EOFException
+    (first & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF)
   }

   def readLong(): Long = {
-    (s.read().toLong << 56) |
+    val first = s.read()
+    if (first < 0) throw new EOFException()
+    (first.toLong << 56) |
     (s.read() & 0xFF).toLong << 48 |
     (s.read() & 0xFF).toLong << 40 |
     (s.read() & 0xFF).toLong << 32 |
...
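Why this fixes a hang rather than a crash: InputStream.read() returns -1 at end of stream, and -1 & 0xFF is 255, so the old readInt reassembled four EOF reads into a plausible-looking Int and never failed. Spark's DeserializationStream.asIterator treats EOFException as the end-of-stream signal, so a reader that never throws never lets the shuffle reader terminate. A minimal standalone sketch of the difference (illustration only, not part of the commit):

import java.io.{ByteArrayInputStream, EOFException, InputStream}

object EofDemo {
  // Old behavior: the -1 returned at EOF is masked to 0xFF, so four EOF
  // reads reassemble into 0xFFFFFFFF (-1) and no error is ever raised.
  def readIntMasked(s: InputStream): Int =
    (s.read() & 0xFF) << 24 | (s.read() & 0xFF) << 16 |
      (s.read() & 0xFF) << 8 | (s.read() & 0xFF)

  // Fixed behavior, mirroring the commit: check the first read before masking.
  def readIntChecked(s: InputStream): Int = {
    val first = s.read()
    if (first < 0) throw new EOFException
    (first & 0xFF) << 24 | (s.read() & 0xFF) << 16 |
      (s.read() & 0xFF) << 8 | (s.read() & 0xFF)
  }

  def main(args: Array[String]): Unit = {
    println(readIntMasked(new ByteArrayInputStream(Array.empty[Byte]))) // prints -1, silently
    try readIntChecked(new ByteArrayInputStream(Array.empty[Byte]))
    catch { case _: EOFException => println("EOF detected") }
  }
}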
@@ -4,8 +4,7 @@ import org.scalatest.FunSuite
 import org.apache.spark.SparkContext
 import org.apache.spark.graph.LocalSparkContext._
-import java.io.ByteArrayInputStream
-import java.io.ByteArrayOutputStream
+import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream}
 import org.apache.spark.graph.impl._
 import org.apache.spark.graph.impl.MsgRDDFunctions._
 import org.apache.spark._
@@ -31,6 +30,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
     assert(outMsg.vid === inMsg2.vid)
     assert(outMsg.data === inMsg1.data)
     assert(outMsg.data === inMsg2.data)
+
+    intercept[EOFException] {
+      inStrm.readObject()
+    }
   }

   test("TestVertexBroadcastMessageLong") {
@@ -48,6 +51,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
     assert(outMsg.vid === inMsg2.vid)
     assert(outMsg.data === inMsg1.data)
     assert(outMsg.data === inMsg2.data)
+
+    intercept[EOFException] {
+      inStrm.readObject()
+    }
   }

   test("TestVertexBroadcastMessageDouble") {
@@ -65,6 +72,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
     assert(outMsg.vid === inMsg2.vid)
     assert(outMsg.data === inMsg1.data)
     assert(outMsg.data === inMsg2.data)
+
+    intercept[EOFException] {
+      inStrm.readObject()
+    }
   }

   test("TestAggregationMessageInt") {
@@ -82,6 +93,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
     assert(outMsg.vid === inMsg2.vid)
     assert(outMsg.data === inMsg1.data)
     assert(outMsg.data === inMsg2.data)
+
+    intercept[EOFException] {
+      inStrm.readObject()
+    }
   }

   test("TestAggregationMessageLong") {
@@ -99,6 +114,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
     assert(outMsg.vid === inMsg2.vid)
     assert(outMsg.data === inMsg1.data)
     assert(outMsg.data === inMsg2.data)
+
+    intercept[EOFException] {
+      inStrm.readObject()
+    }
   }

   test("TestAggregationMessageDouble") {
@@ -116,23 +135,25 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
     assert(outMsg.vid === inMsg2.vid)
     assert(outMsg.data === inMsg1.data)
     assert(outMsg.data === inMsg2.data)
+
+    intercept[EOFException] {
+      inStrm.readObject()
+    }
   }

   test("TestShuffleVertexBroadcastMsg") {
     withSpark(new SparkContext("local[2]", "test")) { sc =>
-      val bmsgs = sc.parallelize(
-        (0 until 100).map(pid => new VertexBroadcastMsg[Int](pid, pid, pid)), 10)
-      val partitioner = new HashPartitioner(3)
-      val bmsgsArray = bmsgs.partitionBy(partitioner).collect
+      val bmsgs = sc.parallelize(0 until 100, 10).map { pid =>
+        new VertexBroadcastMsg[Int](pid, pid, pid)
+      }
+      bmsgs.partitionBy(new HashPartitioner(3)).collect()
     }
   }

   test("TestShuffleAggregationMsg") {
     withSpark(new SparkContext("local[2]", "test")) { sc =>
-      val bmsgs = sc.parallelize(
-        (0 until 100).map(pid => new AggregationMsg[Int](pid, pid)), 10)
-      val partitioner = new HashPartitioner(3)
-      val bmsgsArray = bmsgs.partitionBy(partitioner).collect
+      val bmsgs = sc.parallelize(0 until 100, 10).map(pid => new AggregationMsg[Int](pid, pid))
+      bmsgs.partitionBy(new HashPartitioner(3)).collect()
     }
   }
...
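Beyond the EOF check, the two shuffle tests are tightened: sc.parallelize(0 until 100, 10).map(...) builds each message inside a task rather than materializing all of them on the driver first, and the unused partitioner and bmsgsArray vals are dropped. Repartitioning with a HashPartitioner of a different partition count is what forces a shuffle through the custom serializer, and a deserializer that never signals EOF is presumably what made such a job hang. A sketch of the same shape, using plain (Int, Int) pairs so it runs without the graph classes:

import org.apache.spark.{HashPartitioner, SparkContext}
import org.apache.spark.SparkContext._ // pair-RDD implicits, needed in this era of Spark

object ShufflePatternSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "shuffle-sketch")
    try {
      // Records are created inside the 10 map tasks; repartitioning to 3
      // partitions forces a full shuffle, round-tripping every record
      // through the shuffle serializer.
      val pairs = sc.parallelize(0 until 100, 10).map(pid => (pid, pid))
      val out = pairs.partitionBy(new HashPartitioner(3)).collect()
      assert(out.length == 100) // the job completing at all is the real check
    } finally {
      sc.stop()
    }
  }
}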