Commit 91c07a33 authored by Matei Zaharia

Various fixes to serialization

parent f61b61c4
@@ -10,6 +10,10 @@ import scala.collection.mutable
 import com.esotericsoftware.kryo._
 import com.esotericsoftware.kryo.{Serializer => KSerializer}
+/**
+ * Zig-zag encoder used to write object sizes to serialization streams.
+ * Based on Kryo's integer encoder.
+ */
 object ZigZag {
   def writeInt(n: Int, out: OutputStream) {
     var value = n
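
Note: the new doc comment explains why ZigZag exists: objects written to a serialization stream are prefixed with their size as a variable-length integer. As a point of reference, here is a minimal standalone sketch of zig-zag varint coding in the spirit of Kryo's integer encoder; it is illustrative only, not the code added in this commit.

import java.io.{InputStream, OutputStream}

object ZigZagSketch {
  // Zig-zag maps signed ints to unsigned ones so small magnitudes of either
  // sign encode to few bytes: 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...
  def writeInt(n: Int, out: OutputStream): Unit = {
    var value = (n << 1) ^ (n >> 31)  // zig-zag encode
    while ((value & ~0x7F) != 0) {    // emit 7 bits at a time; high bit = "more"
      out.write((value & 0x7F) | 0x80)
      value >>>= 7
    }
    out.write(value)
  }

  def readInt(in: InputStream): Int = {
    var value = 0
    var shift = 0
    var b = in.read()
    while ((b & 0x80) != 0) {
      value |= (b & 0x7F) << shift
      shift += 7
      b = in.read()
    }
    value |= b << shift
    (value >>> 1) ^ -(value & 1)      // zig-zag decode
  }
}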
@@ -110,12 +114,15 @@ trait KryoRegistrator {
 class KryoSerializer extends Serializer with Logging {
   val kryo = createKryo()
+  val bufferSize =
+    System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
   val threadBuf = new ThreadLocal[ObjectBuffer] {
-    override def initialValue = new ObjectBuffer(kryo, 257*1024*1024)
+    override def initialValue = new ObjectBuffer(kryo, bufferSize)
   }
   val threadByteBuf = new ThreadLocal[ByteBuffer] {
-    override def initialValue = ByteBuffer.allocate(257*1024*1024)
+    override def initialValue = ByteBuffer.allocate(bufferSize)
   }
   def createKryo(): Kryo = {
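
The fixed 257 MB per-thread Kryo buffers become configurable via spark.kryoserializer.buffer.mb, with a far smaller 2 MB default. A job that serializes large records could raise the limit, e.g. (hypothetical usage):

// Must be set before KryoSerializer is instantiated, since each thread's
// buffer is sized once in initialValue.
System.setProperty("spark.kryoserializer.buffer.mb", "32")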
@@ -6,7 +6,7 @@ import scala.collection.mutable.HashMap
 class ShuffleMapTask(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_], val partition: Int, locs: Seq[String])
-extends DAGTask[String](stageId) {
+extends DAGTask[String](stageId) with Logging {
   val split = rdd.splits(partition)
   override def run: String = {
@@ -23,11 +23,12 @@ extends DAGTask[String](stageId) {
         case None => aggregator.createCombiner(v)
       }
     }
+    val ser = SparkEnv.get.serializer.newInstance()
     for (i <- 0 until numOutputSplits) {
       val file = LocalFileShuffle.getOutputFile(dep.shuffleId, partition, i)
-      // TODO: use Serializer instead of ObjectInputStream
       // TODO: have some kind of EOF marker
-      val out = new ObjectOutputStream(new FileOutputStream(file))
+      val out = ser.outputStream(new FileOutputStream(file))
       buckets(i).foreach(pair => out.writeObject(pair))
       out.close()
     }
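
Map output writing now goes through SparkEnv.get.serializer rather than a hard-coded ObjectOutputStream, which is what lets KryoSerializer handle shuffle data. The call sites in this diff imply an interface roughly like the sketch below; anything beyond the methods actually invoked here (newInstance, outputStream, inputStream, writeObject, readObject, close) is an assumption, not the real sources.

import java.io.{InputStream, OutputStream}

// Rough sketch of the pluggable serializer API implied by the call sites.
trait Serializer {
  def newInstance(): SerializerInstance
}

trait SerializerInstance {
  def outputStream(s: OutputStream): SerializationStream
  def inputStream(s: InputStream): DeserializationStream
}

trait SerializationStream {
  def writeObject(obj: AnyRef): Unit
  def close(): Unit
}

trait DeserializationStream {
  def readObject(): AnyRef
  def close(): Unit
}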
@@ -11,6 +11,7 @@ import scala.collection.mutable.HashMap
 class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
   def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
     logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
+    val ser = SparkEnv.get.serializer.newInstance()
     val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
     val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
     for ((serverUri, index) <- serverUris.zipWithIndex) {
@@ -20,10 +21,9 @@ class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
       for (i <- inputIds) {
         try {
           val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
-          // TODO: use Serializer instead of ObjectInputStream
           // TODO: multithreaded fetch
           // TODO: would be nice to retry multiple times
-          val inputStream = new ObjectInputStream(new URL(url).openStream())
+          val inputStream = ser.inputStream(new URL(url).openStream())
           try {
             while (true) {
               val pair = inputStream.readObject().asInstanceOf[(K, V)]
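
There is still no EOF marker in the stream (see the TODO left in ShuffleMapTask), so the while (true) loop presumably terminates when the underlying stream throws EOFException. A hedged sketch of that read pattern, reusing the SerializerInstance sketch above:

import java.io.EOFException
import java.net.URL

// Sketch only: drain (K, V) pairs from one map output URL. With no EOF
// marker, the EOFException from readObject is the normal exit path.
def readMapOutput[K, V](ser: SerializerInstance, url: String, func: (K, V) => Unit): Unit = {
  val in = ser.inputStream(new URL(url).openStream())
  try {
    while (true) {
      val pair = in.readObject().asInstanceOf[(K, V)]
      func(pair._1, pair._2)
    }
  } catch {
    case _: EOFException => // end of stream reached; all pairs consumed
  } finally {
    in.close()
  }
}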
@@ -19,12 +19,10 @@ object SparkEnv {
   }
   def createFromSystemProperties(isMaster: Boolean): SparkEnv = {
-    val cacheClass = System.getProperty("spark.cache.class",
-      "spark.SoftReferenceCache")
+    val cacheClass = System.getProperty("spark.cache.class", "spark.SoftReferenceCache")
     val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
-    val serClass = System.getProperty("spark.serializer",
-      "spark.JavaSerializer")
+    val serClass = System.getProperty("spark.serializer", "spark.JavaSerializer")
     val ser = Class.forName(serClass).newInstance().asInstanceOf[Serializer]
     val cacheTracker = new CacheTracker(isMaster, cache)
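
Both the cache and the serializer are instantiated reflectively from system properties, so swapping implementations needs no code change, only a property set before SparkEnv.createFromSystemProperties runs (the named class must have a no-arg constructor):

// Hypothetical usage: route all of SparkEnv's serialization through Kryo.
System.setProperty("spark.serializer", "spark.KryoSerializer")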