From 796977acdb5c96ca5c08591657137fb3e44d2e94 Mon Sep 17 00:00:00 2001
From: Patrick Wendell <pwendell@gmail.com>
Date: Mon, 17 Mar 2014 14:03:32 -0700
Subject: [PATCH] SPARK-1244: Throw exception if map output status exceeds
 frame size

This is a small change on top of @andrewor14's patch in #147. When the serialized
map output statuses for a shuffle exceed the configured Akka frame size,
MapOutputTrackerMasterActor now logs an error and throws a SparkException instead
of attempting the send. Access to the frame size is consolidated in
AkkaUtils.maxFrameSizeBytes, which Executor now reuses in place of reading
akka.remote.netty.tcp.maximum-frame-size directly.

Author: Patrick Wendell <pwendell@gmail.com>
Author: Andrew Or <andrewor14@gmail.com>

Closes #152 from pwendell/akka-frame and squashes the following commits:

e5fb3ff [Patrick Wendell] Reversing test order
393af4c [Patrick Wendell] Small improvement suggested by Andrew Or
8045103 [Patrick Wendell] Breaking out into two tests
2b4e085 [Patrick Wendell] Consolidate Executor use of akka frame size
c9b6109 [Andrew Or] Simplify test + make access to akka frame size more modular
281d7c9 [Andrew Or] Throw exception on spark.akka.frameSize exceeded + Unit tests
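
For reviewers, a minimal standalone sketch of the new guard (it mirrors the
MapOutputTracker.scala hunk below; the 2 MB payload is an illustrative stand-in
for tracker.getSerializedMapOutputStatuses(shuffleId).size):

    import org.apache.spark.{SparkConf, SparkException}

    // spark.akka.frameSize is configured in MB; AkkaUtils.maxFrameSizeBytes
    // (added below) performs this same MB-to-bytes conversion.
    val conf = new SparkConf().set("spark.akka.frameSize", "1")
    val maxAkkaFrameSize = conf.getInt("spark.akka.frameSize", 10) * 1024 * 1024
    val serializedSize = 2 * 1024 * 1024 // pretend reply payload, 2 MB
    if (serializedSize > maxAkkaFrameSize) {
      // Log-and-throw is deliberate: the actor simply restarts on exception,
      // and SPARK-1239 will ultimately remove this code path entirely.
      throw new SparkException(s"Map output statuses were $serializedSize bytes which " +
        s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes).")
    }

Callers that legitimately need larger replies can raise the limit, e.g.
conf.set("spark.akka.frameSize", "64") for a 64 MB frame.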
---
 .../org/apache/spark/MapOutputTracker.scala   | 19 +++++-
 .../scala/org/apache/spark/SparkEnv.scala     |  2 +-
 .../org/apache/spark/executor/Executor.scala  |  6 +-
 .../org/apache/spark/util/AkkaUtils.scala     |  9 ++-
 .../org/apache/spark/AkkaUtilsSuite.scala     | 10 ++--
 .../apache/spark/MapOutputTrackerSuite.scala  | 58 +++++++++++++++++--
 6 files changed, 84 insertions(+), 20 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 5968973132..80cbf951cb 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -35,13 +35,28 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int)
   extends MapOutputTrackerMessage
 private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
 
-private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
+private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf)
   extends Actor with Logging {
+  val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
+
   def receive = {
     case GetMapOutputStatuses(shuffleId: Int) =>
       val hostPort = sender.path.address.hostPort
       logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
-      sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
+      val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
+      val serializedSize = mapOutputStatuses.size
+      if (serializedSize > maxAkkaFrameSize) {
+        val msg = s"Map output statuses were $serializedSize bytes which " +
+          s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."
+
+        /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception.
+         * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239)
+         * will ultimately remove this entire code path. */
+        val exception = new SparkException(msg)
+        logError(msg, exception)
+        throw exception
+      }
+      sender ! mapOutputStatuses
 
     case StopMapOutputTracker =>
       logInfo("MapOutputTrackerActor stopped!")
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index d035d909b7..774cbd6441 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -191,7 +191,7 @@ object SparkEnv extends Logging {
     }
     mapOutputTracker.trackerActor = registerOrLookup(
       "MapOutputTracker",
-      new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
+      new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
 
     val shuffleFetcher = instantiateClass[ShuffleFetcher](
       "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index e69f6f72d3..2ea2ec29f5 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -29,7 +29,7 @@ import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.scheduler._
 import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}
 
 /**
  * Spark executor used with Mesos, YARN, and the standalone scheduler.
@@ -120,9 +120,7 @@ private[spark] class Executor(
 
   // Akka's message frame size. If task result is bigger than this, we use the block manager
   // to send the result back.
-  private val akkaFrameSize = {
-    env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size")
-  }
+  private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
 
   // Start worker thread pool
   val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index a6c9a9aaba..d0ff17db63 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -49,7 +49,7 @@ private[spark] object AkkaUtils extends Logging {
 
     val akkaTimeout = conf.getInt("spark.akka.timeout", 100)
 
-    val akkaFrameSize = conf.getInt("spark.akka.frameSize", 10)
+    val akkaFrameSize = maxFrameSizeBytes(conf)
     val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false)
     val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off"
     if (!akkaLogLifecycleEvents) {
@@ -92,7 +92,7 @@ private[spark] object AkkaUtils extends Logging {
       |akka.remote.netty.tcp.port = $port
       |akka.remote.netty.tcp.tcp-nodelay = on
       |akka.remote.netty.tcp.connection-timeout = $akkaTimeout s
-      |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}MiB
+      |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}B
       |akka.remote.netty.tcp.execution-pool-size = $akkaThreads
       |akka.actor.default-dispatcher.throughput = $akkaBatchSize
       |akka.log-config-on-start = $logAkkaConfig
@@ -121,4 +121,9 @@ private[spark] object AkkaUtils extends Logging {
   def lookupTimeout(conf: SparkConf): FiniteDuration = {
     Duration.create(conf.get("spark.akka.lookupTimeout", "30").toLong, "seconds")
   }
+
+  /** Returns the configured max frame size for Akka messages in bytes. */
+  def maxFrameSizeBytes(conf: SparkConf): Int = {
+    conf.getInt("spark.akka.frameSize", 10) * 1024 * 1024
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
index cd054c1f68..d2e303d81c 100644
--- a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
@@ -45,12 +45,12 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
 
     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-        Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+        Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
 
     val badconf = new SparkConf
     badconf.set("spark.authenticate", "true")
     badconf.set("spark.authenticate.secret", "bad")
-    val securityManagerBad = new SecurityManager(badconf);
+    val securityManagerBad = new SecurityManager(badconf)
 
     assert(securityManagerBad.isAuthenticationEnabled() === true)
 
@@ -84,7 +84,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
 
     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-        Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+        Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
 
     val badconf = new SparkConf
     badconf.set("spark.authenticate", "false")
@@ -136,7 +136,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
 
     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-        Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+        Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
 
     val goodconf = new SparkConf
     goodconf.set("spark.authenticate", "true")
@@ -189,7 +189,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
 
     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-        Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+        Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
 
     val badconf = new SparkConf
     badconf.set("spark.authenticate", "false")
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 8efa072a97..a5bd72eb0a 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark
 import scala.concurrent.Await
 
 import akka.actor._
+import akka.testkit.TestActorRef
 import org.scalatest.FunSuite
 
 import org.apache.spark.scheduler.MapStatus
@@ -51,14 +52,16 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
   test("master start and stop") {
     val actorSystem = ActorSystem("test")
     val tracker = new MapOutputTrackerMaster(conf)
-    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
+    tracker.trackerActor =
+      actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
     tracker.stop()
   }
 
   test("master register and fetch") {
     val actorSystem = ActorSystem("test")
     val tracker = new MapOutputTrackerMaster(conf)
-    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
+    tracker.trackerActor =
+      actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
     tracker.registerShuffle(10, 2)
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -77,7 +80,8 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
   test("master register and unregister and fetch") {
     val actorSystem = ActorSystem("test")
     val tracker = new MapOutputTrackerMaster(conf)
-    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
+    tracker.trackerActor =
+      actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
     tracker.registerShuffle(10, 2)
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -100,11 +104,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
     val hostname = "localhost"
     val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf,
       securityManager = new SecurityManager(conf))
-    System.setProperty("spark.driver.port", boundPort.toString)    // Will be cleared by LocalSparkContext
+
+    // Will be cleared by LocalSparkContext
+    System.setProperty("spark.driver.port", boundPort.toString)
 
     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-        Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
 
     val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf,
       securityManager = new SecurityManager(conf))
@@ -126,7 +132,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
     masterTracker.incrementEpoch()
     slaveTracker.updateEpoch(masterTracker.getEpoch)
     assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
-           Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+      Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
 
     masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
     masterTracker.incrementEpoch()
@@ -136,4 +142,44 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
     // failure should be cached
     intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
   }
+
+  test("remote fetch below akka frame size") {
+    val newConf = new SparkConf
+    newConf.set("spark.akka.frameSize", "1")
+    newConf.set("spark.akka.askTimeout", "1") // Fail fast
+
+    val masterTracker = new MapOutputTrackerMaster(conf)
+    val actorSystem = ActorSystem("test")
+    val actorRef = TestActorRef[MapOutputTrackerMasterActor](
+      new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem)
+    val masterActor = actorRef.underlyingActor
+
+    // The serialized statuses should be ~123B, well under the 1 MB frame size: no exception expected
+    masterTracker.registerShuffle(10, 1)
+    masterTracker.registerMapOutput(10, 0, new MapStatus(
+      BlockManagerId("88", "mph", 1000, 0), Array.fill[Byte](10)(0)))
+    masterActor.receive(GetMapOutputStatuses(10))
+  }
+
+  test("remote fetch exceeds akka frame size") {
+    val newConf = new SparkConf
+    newConf.set("spark.akka.frameSize", "1")
+    newConf.set("spark.akka.askTimeout", "1") // Fail fast
+
+    val masterTracker = new MapOutputTrackerMaster(conf)
+    val actorSystem = ActorSystem("test")
+    val actorRef = TestActorRef[MapOutputTrackerMasterActor](
+      new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem)
+    val masterActor = actorRef.underlyingActor
+
+    // The serialized statuses should be ~1.1MB, over the 1 MB frame size, so
+    // MapOutputTrackerMasterActor should throw an exception. The input size is hand-selected
+    // because map output statuses are compressed before being sent.
+    masterTracker.registerShuffle(20, 100)
+    (0 until 100).foreach { i =>
+      masterTracker.registerMapOutput(20, i, new MapStatus(
+        BlockManagerId("999", "mps", 1000, 0), Array.fill[Byte](4000000)(0)))
+    }
+    intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) }
+  }
 }
-- 
GitLab