From e02dc83a5b4088bda868868fa63c8b89ea9789be Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Mon, 6 Feb 2012 20:40:39 -0800
Subject: [PATCH] IO optimizations

---
 core/src/main/scala/spark/KryoSerializer.scala         | 4 +++-
 core/src/main/scala/spark/ParallelShuffleFetcher.scala | 4 +++-
 core/src/main/scala/spark/ShuffleMapTask.scala         | 4 +++-
 core/src/main/scala/spark/SimpleShuffleFetcher.scala   | 5 ++++-
 4 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala
index 9ec06a6e21..7d25b965d2 100644
--- a/core/src/main/scala/spark/KryoSerializer.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -65,11 +65,13 @@ object ZigZag {
 
 class KryoSerializationStream(kryo: Kryo, buf: ByteBuffer, out: OutputStream)
 extends SerializationStream {
+  val channel = Channels.newChannel(out)
+
   def writeObject[T](t: T) {
     kryo.writeClassAndObject(buf, t)
     ZigZag.writeInt(buf.position(), out)
     buf.flip()
-    Channels.newChannel(out).write(buf)
+    channel.write(buf)
     buf.clear()
   }
 
diff --git a/core/src/main/scala/spark/ParallelShuffleFetcher.scala b/core/src/main/scala/spark/ParallelShuffleFetcher.scala
index 95dfb01aac..98eb37934b 100644
--- a/core/src/main/scala/spark/ParallelShuffleFetcher.scala
+++ b/core/src/main/scala/spark/ParallelShuffleFetcher.scala
@@ -12,6 +12,8 @@ import java.util.concurrent.atomic.AtomicReference
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
 
+import it.unimi.dsi.fastutil.io.FastBufferedInputStream
+
 
 class ParallelShuffleFetcher extends ShuffleFetcher with Logging {
   val parallelFetches = System.getProperty("spark.parallel.fetches", "3").toInt
@@ -60,7 +62,7 @@ class ParallelShuffleFetcher extends ShuffleFetcher with Logging {
                 if (len == -1)
                   throw new SparkException("Content length was not specified by server")
                 val buf = new Array[Byte](len)
-                val in = conn.getInputStream()
+                val in = new FastBufferedInputStream(conn.getInputStream())
                 var pos = 0
                 while (pos < len) {
                   val n = in.read(buf, pos, len-pos)
diff --git a/core/src/main/scala/spark/ShuffleMapTask.scala b/core/src/main/scala/spark/ShuffleMapTask.scala
index 7b08a21fca..93a93d5750 100644
--- a/core/src/main/scala/spark/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/ShuffleMapTask.scala
@@ -5,6 +5,8 @@ import java.io.FileOutputStream
 import java.io.ObjectOutputStream
 import java.util.{HashMap => JHashMap}
 
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
 
 class ShuffleMapTask(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_], val partition: Int, locs: Seq[String])
 extends DAGTask[String](stageId) with Logging {
@@ -29,7 +31,7 @@ extends DAGTask[String](stageId) with Logging {
     val ser = SparkEnv.get.serializer.newInstance()
     for (i <- 0 until numOutputSplits) {
       val file = LocalFileShuffle.getOutputFile(dep.shuffleId, partition, i)
-      val out = ser.outputStream(new BufferedOutputStream(new FileOutputStream(file)))
+      val out = ser.outputStream(new FastBufferedOutputStream(new FileOutputStream(file)))
       val iter = buckets(i).entrySet().iterator()
       while (iter.hasNext()) {
         val entry = iter.next()
diff --git a/core/src/main/scala/spark/SimpleShuffleFetcher.scala b/core/src/main/scala/spark/SimpleShuffleFetcher.scala
index 1cc0cfc331..1e38a2b1db 100644
--- a/core/src/main/scala/spark/SimpleShuffleFetcher.scala
+++ b/core/src/main/scala/spark/SimpleShuffleFetcher.scala
@@ -6,6 +6,8 @@ import java.net.URL
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
 
+import it.unimi.dsi.fastutil.io.FastBufferedInputStream
+
 
 class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
   def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
@@ -22,7 +24,8 @@ class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
           val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
           // TODO: multithreaded fetch
           // TODO: would be nice to retry multiple times
-          val inputStream = ser.inputStream(new URL(url).openStream())
+          val inputStream = ser.inputStream(
+              new FastBufferedInputStream(new URL(url).openStream()))
           try {
             while (true) {
               val pair = inputStream.readObject().asInstanceOf[(K, V)]
-- 
GitLab