diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
index fac416a5b3d5d24ca16ab7652c145c2e44302d15..bb78207c9f3c2d924a12e909c524b24870572c5a 100644
--- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
@@ -67,11 +67,20 @@ object BlockFetcherIterator {
       throw new IllegalArgumentException("BlocksByAddress is null")
     }
 
-    protected var _totalBlocks = blocksByAddress.map(_._2.size).sum
-    logDebug("Getting " + _totalBlocks + " blocks")
+    // Total number of blocks to fetch (local + remote); also the number of FetchResults expected
+    protected var _numBlocksToFetch = 0
+
     protected var startTime = System.currentTimeMillis
-    protected val localBlockIds = new ArrayBuffer[String]()
-    protected val remoteBlockIds = new HashSet[String]()
+
+    // Number of local blocks, including zero-sized blocks
+    private var numLocal = 0
+    // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
+    protected val localBlocksToFetch = new ArrayBuffer[String]()
+
+    // Number of remote blocks, including zero-sized blocks
+    private var numRemote = 0
+    // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
+    protected val remoteBlocksToFetch = new HashSet[String]()
 
     // A queue to hold our results.
     protected val results = new LinkedBlockingQueue[FetchResult]
@@ -124,13 +133,15 @@ object BlockFetcherIterator {
     protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
       // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
       // at most maxBytesInFlight in order to limit the amount of data in flight.
-      val originalTotalBlocks = _totalBlocks
       val remoteRequests = new ArrayBuffer[FetchRequest]
       for ((address, blockInfos) <- blocksByAddress) {
         if (address == blockManagerId) {
-          localBlockIds ++= blockInfos.map(_._1)
+          numLocal = blockInfos.size
+          // Filter out zero-sized blocks
+          localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
+          _numBlocksToFetch += localBlocksToFetch.size
         } else {
-          remoteBlockIds ++= blockInfos.map(_._1)
+          numRemote += blockInfos.size
           // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
           // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
           // nodes, rather than blocking on reading output from one node.
@@ -144,10 +155,10 @@ object BlockFetcherIterator {
             // Skip empty blocks
             if (size > 0) {
               curBlocks += ((blockId, size))
+              remoteBlocksToFetch += blockId
+              _numBlocksToFetch += 1
               curRequestSize += size
-            } else if (size == 0) {
-              _totalBlocks -= 1
-            } else {
+            } else if (size < 0) {
               throw new BlockException(blockId, "Negative block size " + size)
             }
             if (curRequestSize >= minRequestSize) {
@@ -163,8 +174,8 @@ object BlockFetcherIterator {
           }
         }
       }
-      logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
-        originalTotalBlocks + " blocks")
+      logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " +
+        totalBlocks + " blocks")
       remoteRequests
     }
 
@@ -172,7 +183,7 @@ object BlockFetcherIterator {
       // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
       // these all at once because they will just memory-map some files, so they won't consume
       // any memory that might exceed our maxBytesInFlight
-      for (id <- localBlockIds) {
+      for (id <- localBlocksToFetch) {
         getLocalFromDisk(id, serializer) match {
           case Some(iter) => {
             // Pass 0 as size since it's not in flight
@@ -198,7 +209,7 @@ object BlockFetcherIterator {
         sendRequest(fetchRequests.dequeue())
       }
 
-      val numGets = remoteBlockIds.size - fetchRequests.size
+      val numGets = remoteRequests.size - fetchRequests.size
       logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
 
       // Get Local Blocks
@@ -210,7 +221,7 @@ object BlockFetcherIterator {
     //an iterator that will read fetched blocks off the queue as they arrive.
     @volatile protected var resultsGotten = 0
 
-    override def hasNext: Boolean = resultsGotten < _totalBlocks
+    override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
 
     override def next(): (String, Option[Iterator[Any]]) = {
       resultsGotten += 1
@@ -227,9 +238,9 @@ object BlockFetcherIterator {
     }
 
     // Implementing BlockFetchTracker trait.
-    override def totalBlocks: Int = _totalBlocks
-    override def numLocalBlocks: Int = localBlockIds.size
-    override def numRemoteBlocks: Int = remoteBlockIds.size
+    override def totalBlocks: Int = numLocal + numRemote
+    override def numLocalBlocks: Int = numLocal
+    override def numRemoteBlocks: Int = numRemote
     override def remoteFetchTime: Long = _remoteFetchTime
     override def fetchWaitTime: Long = _fetchWaitTime
     override def remoteBytesRead: Long = _remoteBytesRead
@@ -291,7 +302,7 @@ object BlockFetcherIterator {
     private var copiers: List[_ <: Thread] = null
 
     override def initialize() {
-      // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks
+      // Split local and remote blocks; this also sets _numBlocksToFetch
       val remoteRequests = splitLocalRemoteBlocks()
       // Add the remote requests into our queue in a random order
       for (request <- Utils.randomize(remoteRequests)) {
@@ -313,7 +324,7 @@ object BlockFetcherIterator {
       val result = results.take()
       // if all the results has been retrieved, shutdown the copiers
       // NO need to stop the copiers if we got all the blocks ?
-      // if (resultsGotten == _totalBlocks && copiers != null) {
+      // if (resultsGotten == _numBlocksToFetch && copiers != null) {
       //   stopCopiers()
       // }
       (result.blockId, if (result.failed) None else Some(result.deserialize()))
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index c7281200e7e0086660e9cfeb65d288ce28f7875b..0af6e4a35917fd8551f15ab6c6d47c391824d4b2 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -35,21 +35,25 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
     private var bs: OutputStream = null
     private var objOut: SerializationStream = null
     private var lastValidPosition = 0L
+    private var initialized = false
 
     override def open(): DiskBlockObjectWriter = {
       val fos = new FileOutputStream(f, true)
       channel = fos.getChannel()
-      bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos))
+      bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize))
       objOut = serializer.newInstance().serializeStream(bs)
+      initialized = true
       this
     }
 
     override def close() {
-      objOut.close()
-      bs.close()
-      channel = null
-      bs = null
-      objOut = null
+      if (initialized) {
+        objOut.close()
+        bs.close()
+        channel = null
+        bs = null
+        objOut = null
+      }
       // Invoke the close callback handler.
       super.close()
     }
@@ -59,23 +63,34 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
     // Flush the partial writes, and set valid length to be the length of the entire file.
     // Return the number of bytes written for this commit.
     override def commit(): Long = {
-      // NOTE: Flush the serializer first and then the compressed/buffered output stream
-      objOut.flush()
-      bs.flush()
-      val prevPos = lastValidPosition
-      lastValidPosition = channel.position()
-      lastValidPosition - prevPos
+      if (initialized) {
+        // NOTE: Flush the serializer first and then the compressed/buffered output stream
+        objOut.flush()
+        bs.flush()
+        val prevPos = lastValidPosition
+        lastValidPosition = channel.position()
+        lastValidPosition - prevPos
+      } else {
+        // lastValidPosition is zero if the stream has not been initialized
+        lastValidPosition
+      }
     }
 
     override def revertPartialWrites() {
-      // Discard current writes. We do this by flushing the outstanding writes and
-      // truncate the file to the last valid position.
-      objOut.flush()
-      bs.flush()
-      channel.truncate(lastValidPosition)
+      if (initialized) {
+        // Discard current writes. We do this by flushing the outstanding writes and
+        // truncating the file to the last valid position.
+        objOut.flush()
+        bs.flush()
+        channel.truncate(lastValidPosition)
+      }
     }
 
     override def write(value: Any) {
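+      // Streams are opened lazily on the first write; see the initialized checks above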
+      if (!initialized) {
+        open()
+      }
       objOut.writeObject(value)
     }
 
@@ -197,7 +212,11 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
   private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = {
     val file = getFile(blockId)
     if (!allowAppendExisting && file.exists()) {
-      throw new Exception("File for block " + blockId + " already exists on disk: " + file)
+      // NOTE(shivaram): Delete the file if it already exists. This might happen if a
+      // ShuffleMapTask was rescheduled on the same machine as the old task.
+      logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
+      file.delete()
+      // throw new Exception("File for block " + blockId + " already exists on disk: " + file)
     }
     file
   }
diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
index 49eabfb0d21505616be4f72531fe074118b3a53c..44638e0c2d8f9be93af4f0904037db9921c4a473 100644
--- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
@@ -24,7 +24,8 @@ class ShuffleBlockManager(blockManager: BlockManager) {
         val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
         val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
           val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
-          blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open()
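+          // The writer opens its underlying streams lazily on the first write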
+          blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
         }
         new ShuffleWriterGroup(mapId, writers)
       }
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index b967016cf726791b543781a9f42cf8c9607aab71..33b02fff801445a678b5ab4dfde1014bc8d59f44 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -367,6 +367,32 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     assert(nonEmptyBlocks.size <= 4)
   }
 
+  test("zero sized blocks without kryo") {
+    // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+    sc = new SparkContext("local-cluster[2,1,512]", "test")
+
+    // 10 partitions from 4 keys
+    val NUM_BLOCKS = 10
+    val a = sc.parallelize(1 to 4, NUM_BLOCKS)
+    val b = a.map(x => (x, x*2))
+
+    // NOTE: This test uses the default Java serializer (no Kryo) and still expects
+    //       zero-sized blocks for the empty partitions
+    val c = new ShuffledRDD(b, new HashPartitioner(10))
+
+    val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+    assert(c.count === 4)
+
+    val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
+      val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+      statuses.map(x => x._2)
+    }
+    val nonEmptyBlocks = blockSizes.filter(x => x > 0)
+
+    // We should have at most 4 non-zero sized partitions
+    assert(nonEmptyBlocks.size <= 4)
+  }
+
 }
 
 object ShuffleSuite {