Commit 91c07a33 authored by Matei Zaharia

Various fixes to serialization

parent f61b61c4
@@ -10,6 +10,10 @@ import scala.collection.mutable
 import com.esotericsoftware.kryo._
 import com.esotericsoftware.kryo.{Serializer => KSerializer}
 
+/**
+ * Zig-zag encoder used to write object sizes to serialization streams.
+ * Based on Kryo's integer encoder.
+ */
 object ZigZag {
   def writeInt(n: Int, out: OutputStream) {
     var value = n
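For illustration only (this is not the code added in the hunk above): zig-zag encoding, as named in the new doc comment, maps signed integers to small unsigned values and then writes them as a variable-length quantity, so small sizes cost only a byte or two on the stream. A minimal Scala sketch of that general technique:

    import java.io.{ByteArrayOutputStream, OutputStream}

    // Illustrative sketch of zig-zag + variable-length encoding; the actual
    // ZigZag object in this commit may differ in detail.
    object ZigZagSketch {
      def writeInt(n: Int, out: OutputStream): Unit = {
        var value = (n << 1) ^ (n >> 31)   // zig-zag: 0,-1,1,-2,2,... -> 0,1,2,3,4,...
        while ((value & ~0x7F) != 0) {
          out.write((value & 0x7F) | 0x80) // low 7 bits, continuation bit set
          value >>>= 7                     // unsigned shift so the loop terminates
        }
        out.write(value)                   // final byte, continuation bit clear
      }

      def main(args: Array[String]): Unit = {
        val buf = new ByteArrayOutputStream()
        writeInt(300, buf)
        println(buf.toByteArray.map(b => f"${b & 0xFF}%02x").mkString(" "))
      }
    }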
@@ -110,12 +114,15 @@ trait KryoRegistrator {
 class KryoSerializer extends Serializer with Logging {
   val kryo = createKryo()
 
+  val bufferSize =
+    System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
+
   val threadBuf = new ThreadLocal[ObjectBuffer] {
-    override def initialValue = new ObjectBuffer(kryo, 257*1024*1024)
+    override def initialValue = new ObjectBuffer(kryo, bufferSize)
   }
 
   val threadByteBuf = new ThreadLocal[ByteBuffer] {
-    override def initialValue = ByteBuffer.allocate(257*1024*1024)
+    override def initialValue = ByteBuffer.allocate(bufferSize)
   }
 
   def createKryo(): Kryo = {
...
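The object buffers that were hard-coded at 257 MB are now sized from the spark.kryoserializer.buffer.mb system property, defaulting to 2 MB. A hedged usage sketch (the value 32 is only an example) of how a job could raise the limit before the serializer is created:

    object KryoBufferConfigSketch {
      def main(args: Array[String]): Unit = {
        // Must be set before KryoSerializer is constructed, since bufferSize is
        // read at construction time.
        System.setProperty("spark.kryoserializer.buffer.mb", "32")

        // Mirrors the lookup added in the hunk above: megabytes -> bytes.
        val bufferSize =
          System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
        println(s"Kryo buffers will be allocated at $bufferSize bytes")
      }
    }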
@@ -6,7 +6,7 @@ import scala.collection.mutable.HashMap
 class ShuffleMapTask(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_], val partition: Int, locs: Seq[String])
-extends DAGTask[String](stageId) {
+extends DAGTask[String](stageId) with Logging {
   val split = rdd.splits(partition)
 
   override def run: String = {
@@ -23,11 +23,12 @@ extends DAGTask[String](stageId) {
         case None => aggregator.createCombiner(v)
       }
     }
+    val ser = SparkEnv.get.serializer.newInstance()
     for (i <- 0 until numOutputSplits) {
       val file = LocalFileShuffle.getOutputFile(dep.shuffleId, partition, i)
       // TODO: use Serializer instead of ObjectInputStream
       // TODO: have some kind of EOF marker
-      val out = new ObjectOutputStream(new FileOutputStream(file))
+      val out = ser.outputStream(new FileOutputStream(file))
       buckets(i).foreach(pair => out.writeObject(pair))
       out.close()
     }
...
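The map side now writes each shuffle bucket through a stream obtained from the configured serializer instead of a hard-coded ObjectOutputStream. A simplified, hypothetical sketch of that pattern (MiniSerializerInstance and JavaSerialization below are illustrative stand-ins, not Spark's actual Serializer API):

    import java.io._

    // Hypothetical stand-in for a pluggable serializer; the real task asks
    // SparkEnv for the configured serializer instead.
    trait MiniSerializerInstance {
      def outputStream(out: OutputStream): ObjectOutputStream
      def inputStream(in: InputStream): ObjectInputStream
    }

    object JavaSerialization extends MiniSerializerInstance {
      def outputStream(out: OutputStream) = new ObjectOutputStream(out)
      def inputStream(in: InputStream) = new ObjectInputStream(in)
    }

    object ShuffleWriteSketch {
      def main(args: Array[String]): Unit = {
        val ser: MiniSerializerInstance = JavaSerialization
        val file = File.createTempFile("shuffle-bucket", ".bin")
        // Same shape as the loop body above: wrap the file, write each pair, close.
        val out = ser.outputStream(new FileOutputStream(file))
        Seq(("a", 1), ("b", 2)).foreach(pair => out.writeObject(pair))
        out.close()
        println(s"wrote ${file.length()} bytes to ${file.getPath}")
      }
    }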
@@ -11,6 +11,7 @@ import scala.collection.mutable.HashMap
 class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
   def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
     logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
+    val ser = SparkEnv.get.serializer.newInstance()
     val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
     val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
     for ((serverUri, index) <- serverUris.zipWithIndex) {
@@ -20,10 +21,9 @@ class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
       for (i <- inputIds) {
         try {
           val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
-          // TODO: use Serializer instead of ObjectInputStream
           // TODO: multithreaded fetch
           // TODO: would be nice to retry multiple times
-          val inputStream = new ObjectInputStream(new URL(url).openStream())
+          val inputStream = ser.inputStream(new URL(url).openStream())
          try {
             while (true) {
               val pair = inputStream.readObject().asInstanceOf[(K, V)]
...
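On the reduce side the fetcher reads pairs in a while (true) loop; since the write side still has no explicit EOF marker (see the TODO in ShuffleMapTask above), the loop relies on reaching the end of the stream. A small sketch of that read-until-EOF idiom, using plain Java serialization and a local temp file rather than an HTTP fetch:

    import java.io._

    // Illustrative sketch, not Spark code: write a small bucket, then read it
    // back until EOFException, the way the fetcher's loop consumes a stream.
    object ShuffleReadSketch {
      def main(args: Array[String]): Unit = {
        val file = File.createTempFile("bucket", ".bin")
        val out = new ObjectOutputStream(new FileOutputStream(file))
        Seq(("x", 1), ("y", 2)).foreach(pair => out.writeObject(pair))
        out.close()

        val in = new ObjectInputStream(new FileInputStream(file))
        try {
          while (true) {
            val pair = in.readObject().asInstanceOf[(String, Int)]
            println(s"fetched $pair")
          }
        } catch {
          case _: EOFException => // end of bucket reached
        } finally {
          in.close()
        }
      }
    }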
@@ -19,12 +19,10 @@ object SparkEnv {
   }
 
   def createFromSystemProperties(isMaster: Boolean): SparkEnv = {
-    val cacheClass = System.getProperty("spark.cache.class",
-      "spark.SoftReferenceCache")
+    val cacheClass = System.getProperty("spark.cache.class", "spark.SoftReferenceCache")
     val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
 
-    val serClass = System.getProperty("spark.serializer",
-      "spark.JavaSerializer")
+    val serClass = System.getProperty("spark.serializer", "spark.JavaSerializer")
     val ser = Class.forName(serClass).newInstance().asInstanceOf[Serializer]
 
     val cacheTracker = new CacheTracker(isMaster, cache)
...
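createFromSystemProperties picks both the cache and the serializer implementation by class name from JVM system properties and instantiates them reflectively, which is what makes spark.JavaSerializer and the new KryoSerializer interchangeable. A brief sketch of that lookup (java.util.ArrayList is only a stand-in here, since the Spark classes are not on this classpath):

    object PluggableClassSketch {
      def main(args: Array[String]): Unit = {
        // Same property and default as the hunk above.
        val serClass = System.getProperty("spark.serializer", "spark.JavaSerializer")
        println(s"Configured serializer class: $serClass")

        // The reflective pattern itself, shown on a class that is always available;
        // SparkEnv casts the result to its Serializer trait.
        val instance = Class.forName("java.util.ArrayList").newInstance()
        println(s"Reflectively created a ${instance.getClass.getName}")
      }
    }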