From e02dc83a5b4088bda868868fa63c8b89ea9789be Mon Sep 17 00:00:00 2001 From: Matei Zaharia <matei@eecs.berkeley.edu> Date: Mon, 6 Feb 2012 20:40:39 -0800 Subject: [PATCH] IO optimizations --- core/src/main/scala/spark/KryoSerializer.scala | 4 +++- core/src/main/scala/spark/ParallelShuffleFetcher.scala | 4 +++- core/src/main/scala/spark/ShuffleMapTask.scala | 4 +++- core/src/main/scala/spark/SimpleShuffleFetcher.scala | 5 ++++- 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 9ec06a6e21..7d25b965d2 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -65,11 +65,13 @@ object ZigZag { class KryoSerializationStream(kryo: Kryo, buf: ByteBuffer, out: OutputStream) extends SerializationStream { + val channel = Channels.newChannel(out) + def writeObject[T](t: T) { kryo.writeClassAndObject(buf, t) ZigZag.writeInt(buf.position(), out) buf.flip() - Channels.newChannel(out).write(buf) + channel.write(buf) buf.clear() } diff --git a/core/src/main/scala/spark/ParallelShuffleFetcher.scala b/core/src/main/scala/spark/ParallelShuffleFetcher.scala index 95dfb01aac..98eb37934b 100644 --- a/core/src/main/scala/spark/ParallelShuffleFetcher.scala +++ b/core/src/main/scala/spark/ParallelShuffleFetcher.scala @@ -12,6 +12,8 @@ import java.util.concurrent.atomic.AtomicReference import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import it.unimi.dsi.fastutil.io.FastBufferedInputStream + class ParallelShuffleFetcher extends ShuffleFetcher with Logging { val parallelFetches = System.getProperty("spark.parallel.fetches", "3").toInt @@ -60,7 +62,7 @@ class ParallelShuffleFetcher extends ShuffleFetcher with Logging { if (len == -1) throw new SparkException("Content length was not specified by server") val buf = new Array[Byte](len) - val in = conn.getInputStream() + val in = new FastBufferedInputStream(conn.getInputStream()) var pos = 0 while (pos < len) { val n = in.read(buf, pos, len-pos) diff --git a/core/src/main/scala/spark/ShuffleMapTask.scala b/core/src/main/scala/spark/ShuffleMapTask.scala index 7b08a21fca..93a93d5750 100644 --- a/core/src/main/scala/spark/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/ShuffleMapTask.scala @@ -5,6 +5,8 @@ import java.io.FileOutputStream import java.io.ObjectOutputStream import java.util.{HashMap => JHashMap} +import it.unimi.dsi.fastutil.io.FastBufferedOutputStream + class ShuffleMapTask(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_], val partition: Int, locs: Seq[String]) extends DAGTask[String](stageId) with Logging { @@ -29,7 +31,7 @@ extends DAGTask[String](stageId) with Logging { val ser = SparkEnv.get.serializer.newInstance() for (i <- 0 until numOutputSplits) { val file = LocalFileShuffle.getOutputFile(dep.shuffleId, partition, i) - val out = ser.outputStream(new BufferedOutputStream(new FileOutputStream(file))) + val out = ser.outputStream(new FastBufferedOutputStream(new FileOutputStream(file))) val iter = buckets(i).entrySet().iterator() while (iter.hasNext()) { val entry = iter.next() diff --git a/core/src/main/scala/spark/SimpleShuffleFetcher.scala b/core/src/main/scala/spark/SimpleShuffleFetcher.scala index 1cc0cfc331..1e38a2b1db 100644 --- a/core/src/main/scala/spark/SimpleShuffleFetcher.scala +++ b/core/src/main/scala/spark/SimpleShuffleFetcher.scala @@ -6,6 +6,8 @@ import java.net.URL import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import it.unimi.dsi.fastutil.io.FastBufferedInputStream + class SimpleShuffleFetcher extends ShuffleFetcher with Logging { def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) { @@ -22,7 +24,8 @@ class SimpleShuffleFetcher extends ShuffleFetcher with Logging { val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId) // TODO: multithreaded fetch // TODO: would be nice to retry multiple times - val inputStream = ser.inputStream(new URL(url).openStream()) + val inputStream = ser.inputStream( + new FastBufferedInputStream(new URL(url).openStream())) try { while (true) { val pair = inputStream.readObject().asInstanceOf[(K, V)] -- GitLab