diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 42b2985b50ad79d8399a6ca1216e0c9fdc7b0705..fad54683bcf5a763e07989c0d9fa88eb639047cc 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -81,7 +81,7 @@ class SparkContext(
     val sparkHome: String = null,
     val jars: Seq[String] = Nil,
     val environment: Map[String, String] = Map(),
-    // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc)
+    // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, etc)
     // too. This is typically generated from InputFormatInfo.computePreferredLocations .. host, set
     // of data-local splits on host
     val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] =
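For context, this parameter is a map from host name to the set of data-local splits on that host. A minimal, hedged sketch of its shape (the host names are hypothetical, and real values come from InputFormatInfo.computePreferredLocations rather than being written by hand; empty split sets stand in for the SplitInfo values):

    import org.apache.spark.scheduler.SplitInfo

    // Hypothetical hosts; in practice each set would hold the SplitInfo of the
    // data-local splits that live on that host.
    val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] =
      Map(
        "datanode-1.example.com" -> Set.empty[SplitInfo],
        "datanode-2.example.com" -> Set.empty[SplitInfo])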
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 9c43a7287d6ee7e13c870c781bbff50a0d122f09..eeeca3ea8a33e4f77562383dfef50a7507628894 100644
--- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -30,8 +30,10 @@ import org.apache.hadoop.net.NetUtils
 import org.apache.hadoop.security.UserGroupInformation
 import org.apache.hadoop.util.ShutdownHookManager
 import org.apache.hadoop.yarn.api._
-import org.apache.hadoop.yarn.api.records._
 import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.client.api.AMRMClient
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
 import org.apache.hadoop.yarn.conf.YarnConfiguration
 import org.apache.hadoop.yarn.ipc.YarnRPC
 import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
@@ -45,55 +47,43 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
   def this(args: ApplicationMasterArguments) = this(args, new Configuration())
   
   private var rpc: YarnRPC = YarnRPC.create(conf)
-  private var resourceManager: AMRMProtocol = _
+  private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
   private var appAttemptId: ApplicationAttemptId = _
   private var userThread: Thread = _
-  private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
   private val fs = FileSystem.get(yarnConf)
 
   private var yarnAllocator: YarnAllocationHandler = _
   private var isFinished: Boolean = false
   private var uiAddress: String = _
-  private val maxAppAttempts: Int = conf.getInt(YarnConfiguration.RM_AM_MAX_RETRIES,
-    YarnConfiguration.DEFAULT_RM_AM_MAX_RETRIES)
+  private val maxAppAttempts: Int = conf.getInt(
+    YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS)
   private var isLastAMRetry: Boolean = true
-  // default to numWorkers * 2, with minimum of 3
+  private var amClient: AMRMClient[ContainerRequest] = _
+
+  // Default to numWorkers * 2, with a minimum of 3.
   private val maxNumWorkerFailures = System.getProperty("spark.yarn.max.worker.failures",
     math.max(args.numWorkers * 2, 3).toString()).toInt
 
   def run() {
-    // Setup the directories so things go to yarn approved directories rather
-    // then user specified and /tmp.
+    // Set up the directories so things go to YARN-approved directories rather
+    // than user-specified directories and /tmp.
     System.setProperty("spark.local.dir", getLocalDirs())
 
-    // Use priority 30 as its higher then HDFS. Its same priority as MapReduce is using.
+    // Use priority 30 as it's higher than HDFS. It's the same priority that MapReduce uses.
     ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30)
-    
+
     appAttemptId = getApplicationAttemptId()
     isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts
-    resourceManager = registerWithResourceManager()
+    amClient = AMRMClient.createAMRMClient()
+    amClient.init(yarnConf)
+    amClient.start()
 
     // Workaround until hadoop moves to something which has
     // https://issues.apache.org/jira/browse/HADOOP-8406 - fixed in (2.0.2-alpha but no 0.23 line)
-    // ignore result.
-    // This does not, unfortunately, always work reliably ... but alleviates the bug a lot of times
-    // Hence args.workerCores = numCore disabled above. Any better option?
-
-    // Compute number of threads for akka
-    //val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()
-    //if (minimumMemory > 0) {
-    //  val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD
-    //  val numCore = (mem  / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)
-
-    //  if (numCore > 0) {
-        // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406
-        // TODO: Uncomment when hadoop is on a version which has this fixed.
-        // args.workerCores = numCore
-    //  }
-    //}
     // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf)
     
     ApplicationMaster.register(this)
+
     // Start the user's JAR
     userThread = startUserClass()
     
@@ -103,12 +93,12 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
 
     waitForSparkContextInitialized()
 
-    // Do this after spark master is up and SparkContext is created so that we can register UI Url
+    // Register the UI URL after the Spark master is up and the SparkContext has been created.
     val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()
-    
+
     // Allocate all containers
     allocateWorkers()
-    
+
     // Wait for the user class to Finish     
     userThread.join()
 
@@ -132,41 +122,24 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
   
   private def getApplicationAttemptId(): ApplicationAttemptId = {
     val envs = System.getenv()
-    val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV)
+    val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name())
     val containerId = ConverterUtils.toContainerId(containerIdString)
     val appAttemptId = containerId.getApplicationAttemptId()
     logInfo("ApplicationAttemptId: " + appAttemptId)
     appAttemptId
   }
   
-  private def registerWithResourceManager(): AMRMProtocol = {
-    val rmAddress = NetUtils.createSocketAddr(yarnConf.get(
-      YarnConfiguration.RM_SCHEDULER_ADDRESS,
-      YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS))
-    logInfo("Connecting to ResourceManager at " + rmAddress)
-    rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
-  }
-  
   private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
     logInfo("Registering the ApplicationMaster")
-    val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest])
-      .asInstanceOf[RegisterApplicationMasterRequest]
-    appMasterRequest.setApplicationAttemptId(appAttemptId)
-    // Setting this to master host,port - so that the ApplicationReport at client has some
-    // sensible info. 
-    // Users can then monitor stderr/stdout on that node if required.
-    appMasterRequest.setHost(Utils.localHostName())
-    appMasterRequest.setRpcPort(0)
-    appMasterRequest.setTrackingUrl(uiAddress)
-    resourceManager.registerApplicationMaster(appMasterRequest)
+    amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress)
   }
   
   private def waitForSparkMaster() {
-    logInfo("Waiting for spark driver to be reachable.")
+    logInfo("Waiting for Spark driver to be reachable.")
     var driverUp = false
     var tries = 0
     val numTries = System.getProperty("spark.yarn.applicationMaster.waitTries", "10").toInt
-    while(!driverUp && tries < numTries) {
+    while (!driverUp && tries < numTries) {
       val driverHost = System.getProperty("spark.driver.host")
       val driverPort = System.getProperty("spark.driver.port")
       try {
@@ -176,8 +149,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
         driverUp = true
       } catch {
         case e: Exception => {
-          logWarning("Failed to connect to driver at %s:%s, retrying ...").
-            format(driverHost, driverPort)
+          logWarning("Failed to connect to driver at %s:%s, retrying ...".
+            format(driverHost, driverPort))
           Thread.sleep(100)
           tries = tries + 1
         }
@@ -218,44 +191,44 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
     t
   }
 
-  // this need to happen before allocateWorkers
+  // This needs to happen before allocateWorkers().
   private def waitForSparkContextInitialized() {
-    logInfo("Waiting for spark context initialization")
+    logInfo("Waiting for Spark context initialization")
     try {
       var sparkContext: SparkContext = null
       ApplicationMaster.sparkContextRef.synchronized {
-        var count = 0
+        var numTries = 0
         val waitTime = 10000L
-        val numTries = System.getProperty("spark.yarn.ApplicationMaster.waitTries", "10").toInt
-        while (ApplicationMaster.sparkContextRef.get() == null && count < numTries) {
-          logInfo("Waiting for spark context initialization ... " + count)
-          count = count + 1
+        val maxNumTries = System.getProperty("spark.yarn.ApplicationMaster.waitTries", "10").toInt
+        while (ApplicationMaster.sparkContextRef.get() == null && numTries < maxNumTries) {
+          logInfo("Waiting for Spark context initialization ... " + numTries)
+          numTries = numTries + 1
           ApplicationMaster.sparkContextRef.wait(waitTime)
         }
         sparkContext = ApplicationMaster.sparkContextRef.get()
-        assert(sparkContext != null || count >= numTries)
+        assert(sparkContext != null || numTries >= maxNumTries)
 
-        if (null != sparkContext) {
+        if (sparkContext != null) {
           uiAddress = sparkContext.ui.appUIAddress
           this.yarnAllocator = YarnAllocationHandler.newAllocator(
             yarnConf,
-            resourceManager,
+            amClient,
             appAttemptId,
             args, 
-            sparkContext.preferredNodeLocationData) 
+            sparkContext.preferredNodeLocationData)
         } else {
-          logWarning("Unable to retrieve sparkContext inspite of waiting for %d, numTries = %d".
-            format(count * waitTime, numTries))
+          logWarning("Unable to retrieve SparkContext inspite of waiting for %d, maxNumTries = %d".
+            format(numTries * waitTime, maxNumTries))
           this.yarnAllocator = YarnAllocationHandler.newAllocator(
             yarnConf,
-            resourceManager,
+            amClient,
             appAttemptId,
             args)
         }
       }
     } finally {
-      // in case of exceptions, etc - ensure that count is atleast ALLOCATOR_LOOP_WAIT_COUNT :
-      // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks
+      // In case of exceptions, etc., ensure that the count is at least ALLOCATOR_LOOP_WAIT_COUNT
+      // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks.
       ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT)
     }
   }
@@ -266,15 +239,14 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
       // Wait until all containers have finished
       // TODO: This is a bit ugly. Can we make it nicer?
       // TODO: Handle container failure
-
-      // Exists the loop if the user thread exits.
+      yarnAllocator.addResourceRequests(args.numWorkers)
+      // Exits the loop if the user thread exits.
       while (yarnAllocator.getNumWorkersRunning < args.numWorkers && userThread.isAlive) {
         if (yarnAllocator.getNumWorkersFailed >= maxNumWorkerFailures) {
           finishApplicationMaster(FinalApplicationStatus.FAILED,
             "max number of worker failures reached")
         }
-        yarnAllocator.allocateContainers(
-          math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0))
+        yarnAllocator.allocateResources()
         ApplicationMaster.incrementAllocatorLoop(1)
         Thread.sleep(100)
       }
@@ -287,7 +259,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
 
     // Launch a progress reporter thread, else the app will get killed after expiration
     // (def: 10mins) timeout.
-    // TODO(harvey): Verify the timeout
     if (userThread.isAlive) {
       // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.
       val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
@@ -313,13 +284,14 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
             finishApplicationMaster(FinalApplicationStatus.FAILED,
               "max number of worker failures reached")
           }
-          val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning
+          val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning -
+            yarnAllocator.getNumPendingAllocate
           if (missingWorkerCount > 0) {
             logInfo("Allocating %d containers to make up for (potentially) lost containers".
               format(missingWorkerCount))
-            yarnAllocator.allocateContainers(missingWorkerCount)
+            yarnAllocator.addResourceRequests(missingWorkerCount)
           }
-          else sendProgress()
+          sendProgress()
           Thread.sleep(sleepTime)
         }
       }
@@ -333,8 +305,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
 
   private def sendProgress() {
     logDebug("Sending progress")
-    // Simulated with an allocate request with no nodes requested ...
-    yarnAllocator.allocateContainers(0)
+    // Simulated with an allocate request with no nodes requested.
+    yarnAllocator.allocateResources()
   }
 
   /*
@@ -361,14 +333,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
     }
 
     logInfo("finishApplicationMaster with " + status)
-    val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest])
-      .asInstanceOf[FinishApplicationMasterRequest]
-    finishReq.setAppAttemptId(appAttemptId)
-    finishReq.setFinishApplicationStatus(status)
-    finishReq.setDiagnostics(diagnostics)
-    // Set tracking url to empty since we don't have a history server.
-    finishReq.setTrackingUrl("")
-    resourceManager.finishApplicationMaster(finishReq)
+    // Set tracking URL to empty since we don't have a history server.
+    amClient.unregisterApplicationMaster(status, "" /* appMessage */, "" /* appTrackingUrl */)
   }
 
   /**
@@ -412,6 +378,14 @@ object ApplicationMaster {
   // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be
   // optimal as more containers are available. Might need to handle this better.
   private val ALLOCATOR_LOOP_WAIT_COUNT = 30
+
+  private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]()
+
+  val sparkContextRef: AtomicReference[SparkContext] =
+    new AtomicReference[SparkContext](null /* initialValue */)
+
+  val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0)
+
   def incrementAllocatorLoop(by: Int) {
     val count = yarnAllocatorLoop.getAndAdd(by)
     if (count >= ALLOCATOR_LOOP_WAIT_COUNT) {
@@ -422,16 +396,11 @@ object ApplicationMaster {
     }
   }
 
-  private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]()
-
   def register(master: ApplicationMaster) {
     applicationMasters.add(master)
   }
 
-  val sparkContextRef: AtomicReference[SparkContext] =
-    new AtomicReference[SparkContext](null /* initialValue */)
-  val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0)
-
+  // TODO(harvey): See whether this should be discarded - it isn't used anywhere at the moment.
   def sparkContextInitialized(sc: SparkContext): Boolean = {
     var modified = false
     sparkContextRef.synchronized {
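For readers less familiar with the Hadoop 2.2 client library this file now targets, a minimal, hedged sketch of the AMRMClient lifecycle the new ApplicationMaster code follows (register, request, allocate/heartbeat, unregister). The object name, memory size, priority and progress value are illustrative only, not taken from this patch, and the code only does anything meaningful when run inside an AM container launched by YARN:

    import org.apache.hadoop.yarn.api.records.{FinalApplicationStatus, Priority, Resource}
    import org.apache.hadoop.yarn.client.api.AMRMClient
    import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
    import org.apache.hadoop.yarn.conf.YarnConfiguration
    import org.apache.hadoop.yarn.util.Records

    object AmrmClientSketch {
      def main(args: Array[String]) {
        val amClient: AMRMClient[ContainerRequest] = AMRMClient.createAMRMClient()
        amClient.init(new YarnConfiguration())
        amClient.start()

        // Register this AM with the ResourceManager: host, RPC port, tracking URL.
        amClient.registerApplicationMaster("localhost", 0, "")

        // Queue a request for one 1024 MB / 1-core container anywhere in the cluster
        // (nodes and racks are left unconstrained).
        val capability = Records.newRecord(classOf[Resource])
        capability.setMemory(1024)
        capability.setVirtualCores(1)
        val priority = Records.newRecord(classOf[Priority])
        priority.setPriority(1)
        amClient.addContainerRequest(new ContainerRequest(capability, null, null, priority))

        // allocate() both sends the pending requests and acts as the heartbeat;
        // the float argument is a progress indicator reported to the ResourceManager.
        val response = amClient.allocate(0.1f)
        val granted = response.getAllocatedContainers()
        println("Granted " + granted.size + " container(s)")

        amClient.unregisterApplicationMaster(FinalApplicationStatus.SUCCEEDED, "", "")
        amClient.stop()
      }
    }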
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 86310f32d5b73c5792a5332e831f8441b621eeaf..ee900867292fe6da43e233fe1022cdeae2b6c170 100644
--- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -35,7 +35,7 @@ import org.apache.hadoop.yarn.api._
 import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
 import org.apache.hadoop.yarn.api.protocolrecords._
 import org.apache.hadoop.yarn.api.records._
-import org.apache.hadoop.yarn.client.YarnClientImpl
+import org.apache.hadoop.yarn.client.api.impl.YarnClientImpl
 import org.apache.hadoop.yarn.conf.YarnConfiguration
 import org.apache.hadoop.yarn.ipc.YarnRPC
 import org.apache.hadoop.yarn.util.{Apps, Records}
@@ -45,10 +45,13 @@ import org.apache.spark.util.Utils
 import org.apache.spark.deploy.SparkHadoopUtil
 
 
+/**
+ * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. The
+ * Client submits an application to the global ResourceManager to launch Spark's ApplicationMaster,
+ * which will launch a Spark master process and negotiate resources throughout its duration.
+ */
 class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging {
 
-  def this(args: ClientArguments) = this(new Configuration(), args)
-
   var rpc: YarnRPC = YarnRPC.create(conf)
   val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
   val credentials = UserGroupInformation.getCurrentUser().getCredentials()
@@ -56,48 +59,68 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
   private val distCacheMgr = new ClientDistributedCacheManager()
 
   // Staging directory is private! -> rwx--------
-  val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(0700:Short)
+  val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(0700: Short)
   // App files are world-wide readable and owner writable -> rw-r--r--
-  val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short) 
+  val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644: Short)
+
+  def this(args: ClientArguments) = this(new Configuration(), args)
 
   def run() {
     validateArgs()
 
+    // Initialize and start the client service.
     init(yarnConf)
     start()
+
+    // Log details about this YARN cluster (e.g., the number of slave machines/NodeManagers).
     logClusterResourceDetails()
 
-    val newApp = super.getNewApplication()
-    val appId = newApp.getApplicationId()
+    // Prepare to submit a request to the ResourceManager (specifically, its ApplicationsManager
+    // (ASM) interface).
 
-    verifyClusterResources(newApp)
-    val appContext = createApplicationSubmissionContext(appId)
+    // Get a new client application.
+    val newApp = super.createApplication()
+    val newAppResponse = newApp.getNewApplicationResponse()
+    val appId = newAppResponse.getApplicationId()
+
+    verifyClusterResources(newAppResponse)
+
+    // Set up resource and environment variables.
     val appStagingDir = getAppStagingDir(appId)
     val localResources = prepareLocalResources(appStagingDir)
-    val env = setupLaunchEnv(localResources, appStagingDir)
-    val amContainer = createContainerLaunchContext(newApp, localResources, env)
+    val launchEnv = setupLaunchEnv(localResources, appStagingDir)
+    val amContainer = createContainerLaunchContext(newAppResponse, localResources, launchEnv)
 
+    // Set up an application submission context.
+    val appContext = newApp.getApplicationSubmissionContext()
+    appContext.setApplicationName(args.appName)
     appContext.setQueue(args.amQueue)
     appContext.setAMContainerSpec(amContainer)
-    appContext.setUser(UserGroupInformation.getCurrentUser().getShortUserName())
 
-    submitApp(appContext)
+    // Memory for the ApplicationMaster.
+    val memoryResource = Records.newRecord(classOf[Resource]).asInstanceOf[Resource]
+    memoryResource.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+    appContext.setResource(memoryResource)
 
+    // Finally, submit and monitor the application.
+    submitApp(appContext)
     monitorApplication(appId)
+
     System.exit(0)
   }
 
+  // TODO(harvey): This could just go in ClientArguments.
   def validateArgs() = {
     Map(
       (System.getenv("SPARK_JAR") == null) -> "Error: You must set SPARK_JAR environment variable!",
       (args.userJar == null) -> "Error: You must specify a user jar!",
       (args.userClass == null) -> "Error: You must specify a user class!",
       (args.numWorkers <= 0) -> "Error: You must specify atleast 1 worker!",
-      (args.amMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> "Error: AM memory size must be +
-        greater then: " + YarnAllocationHandler.MEMORY_OVERHEAD,
-      (args.workerMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> "Error: Worker memory size +
-        must be greater then: " + YarnAllocationHandler.MEMORY_OVERHEAD.toString
-    .foreach { case(cond, errStr) => 
+      (args.amMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: AM memory size must be" +
+        " greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD),
+      (args.workerMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: Worker memory size" +
+        " must be greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD.toString)
+    ).foreach { case(cond, errStr) =>
       if (cond) {
         logError(errStr)
         args.printUsageAndExit(1)
@@ -111,17 +134,17 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
 
   def logClusterResourceDetails() {
     val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics
-    logInfo("Got Cluster metric info from ASM, numNodeManagers = " +
+    logInfo("Got Cluster metric info from ApplicationsManager (ASM), number of NodeManagers: " +
       clusterMetrics.getNumNodeManagers)
 
     val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue)
-    logInfo("""Queue info ... queueName = %s, queueCurrentCapacity = %s, queueMaxCapacity = %s,
+    logInfo("""Queue info ... queueName: %s, queueCurrentCapacity: %s, queueMaxCapacity: %s,
       queueApplicationCount = %s, queueChildQueueCount = %s""".format(
         queueInfo.getQueueName,
         queueInfo.getCurrentCapacity,
         queueInfo.getMaximumCapacity,
         queueInfo.getApplications.size,
-        queueInfo.getChildQueues.size)
+        queueInfo.getChildQueues.size))
   }
 
   def verifyClusterResources(app: GetNewApplicationResponse) = { 
@@ -130,25 +153,19 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
 
     // If we have requested more then the clusters max for a single resource then exit.
     if (args.workerMemory > maxMem) {
-      logError("the worker size is to large to run on this cluster " + args.workerMemory)
+      logError("Required worker memory (%d MB), is above the max threshold (%d MB) of this cluster.".
+        format(args.workerMemory, maxMem))
       System.exit(1)
     }
     val amMem = args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD
     if (amMem > maxMem) {
-      logError("AM size is to large to run on this cluster "  + amMem)
+      logError("Required AM memory (%d) is above the max threshold (%d) of this cluster".
+        format(args.amMemory, maxMem))
       System.exit(1)
     }
 
     // We could add checks to make sure the entire cluster has enough resources but that involves
-    // getting all the node reports and computing ourselves 
-  }
-
-  def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = {
-    logInfo("Setting up application submission context for ASM")
-    val appContext = Records.newRecord(classOf[ApplicationSubmissionContext])
-    appContext.setApplicationId(appId)
-    appContext.setApplicationName(args.appName)
-    return appContext
+    // getting all the node reports and computing ourselves.
   }
 
   /** See if two file systems are the same or not. */
@@ -213,7 +230,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
   def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = {
     logInfo("Preparing Local resources")
     // Upload Spark and the application JAR to the remote file system if necessary. Add them as
-    // local resources to the AM.
+    // local resources to the application master.
     val fs = FileSystem.get(conf)
 
     val delegTokenRenewer = Master.getMasterPrincipal(conf)
@@ -230,18 +247,20 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
       val dstFs = dst.getFileSystem(conf)
       dstFs.addDelegationTokens(delegTokenRenewer, credentials)
     }
+
     val localResources = HashMap[String, LocalResource]()
     FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION))
 
     val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
 
-    Map(Client.SPARK_JAR -> System.getenv("SPARK_JAR"), Client.APP_JAR -> args.userJar, 
-      Client.LOG4J_PROP -> System.getenv("SPARK_LOG4J_CONF"))
-    .foreach { case(destName, _localPath) =>
+    Map(
+      Client.SPARK_JAR -> System.getenv("SPARK_JAR"), Client.APP_JAR -> args.userJar,
+      Client.LOG4J_PROP -> System.getenv("SPARK_LOG4J_CONF")
+    ).foreach { case(destName, _localPath) =>
       val localPath: String = if (_localPath != null) _localPath.trim() else ""
       if (! localPath.isEmpty()) {
         var localURI = new URI(localPath)
-        // if not specified assume these are in the local filesystem to keep behavior like Hadoop
+        // If not specified, assume these are in the local filesystem to keep behavior like Hadoop.
         if (localURI.getScheme() == null) {
           localURI = new URI(FileSystem.getLocal(conf).makeQualified(new Path(localPath)).toString)
         }
@@ -252,19 +271,21 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
       }
     }
 
-    // handle any add jars
+    // Handle jars local to the ApplicationMaster.
     if ((args.addJars != null) && (!args.addJars.isEmpty())){
       args.addJars.split(',').foreach { case file: String =>
         val localURI = new URI(file.trim())
         val localPath = new Path(localURI)
         val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
         val destPath = copyRemoteFile(dst, localPath, replication)
+        // Only add the resource to the Spark ApplicationMaster.
+        val appMasterOnly = true
         distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, 
-          linkname, statCache, true)
+          linkname, statCache, appMasterOnly)
       }
     }
 
-    // handle any distributed cache files
+    // Handle any distributed cache files
     if ((args.files != null) && (!args.files.isEmpty())){
       args.files.split(',').foreach { case file: String =>
         val localURI = new URI(file.trim())
@@ -276,7 +297,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
       }
     }
 
-    // handle any distributed cache archives
+    // Handle any distributed cache archives
     if ((args.archives != null) && (!args.archives.isEmpty())) {
       args.archives.split(',').foreach { case file:String =>
         val localURI = new URI(file.trim())
@@ -289,7 +310,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
     }
 
     UserGroupInformation.getCurrentUser().addCredentials(credentials)
-    return localResources
+    localResources
   }
 
   def setupLaunchEnv(
@@ -311,8 +332,9 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
     // Allow users to specify some environment variables.
     Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"))
 
-    // Add each SPARK-* key to the environment.
+    // Add each SPARK_* key to the environment.
     System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
+
     env
   }
 
@@ -335,33 +357,32 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
     amContainer.setLocalResources(localResources)
     amContainer.setEnvironment(env)
 
-    val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory()
-
-    // TODO(harvey): This can probably be a val.
-    var amMemory = ((args.amMemory / minResMemory) * minResMemory) +
-      ((if ((args.amMemory % minResMemory) == 0) 0 else minResMemory) -
-        YarnAllocationHandler.MEMORY_OVERHEAD)
+    // TODO: Need a replacement for the following code to fix -Xmx?
+    // val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory()
+    // var amMemory = ((args.amMemory / minResMemory) * minResMemory) +
+    //  ((if ((args.amMemory % minResMemory) == 0) 0 else minResMemory) -
+    //    YarnAllocationHandler.MEMORY_OVERHEAD)
 
     // Extra options for the JVM
     var JAVA_OPTS = ""
 
-    // Add Xmx for am memory
-    JAVA_OPTS += "-Xmx" + amMemory + "m "
+    // Add Xmx for AM memory
+    JAVA_OPTS += "-Xmx" + args.amMemory + "m"
 
-    JAVA_OPTS += " -Djava.io.tmpdir=" + 
-      new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + " "
+    val tmpDir = new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR)
+    JAVA_OPTS += " -Djava.io.tmpdir=" + tmpDir
 
-    // Commenting it out for now - so that people can refer to the properties if required. Remove
-    // it once cpuset version is pushed out. The context is, default gc for server class machines
-    // end up using all cores to do gc - hence if there are multiple containers in same node,
-    // spark gc effects all other containers performance (which can also be other spark containers)
-    // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in
-    // multi-tenant environments. Not sure how default java gc behaves if it is limited to subset
+    // TODO: Remove once cpuset version is pushed out.
+    // The context is, the default GC for server-class machines ends up using all cores to do GC -
+    // hence if there are multiple containers on the same node, Spark GC affects all other
+    // containers' performance (which can be that of other Spark containers).
+    // Instead of using this, rely on cpusets by YARN to enforce "proper" Spark behavior in
+    // multi-tenant environments. Not sure how default Java GC behaves if limited to a subset
     // of cores on a node.
     val useConcurrentAndIncrementalGC = env.isDefinedAt("SPARK_USE_CONC_INCR_GC") &&
       java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))
     if (useConcurrentAndIncrementalGC) {
-      // In our expts, using (default) throughput collector has severe perf ramnifications in
+      // In our expts, using (default) throughput collector has severe perf ramifications in
       // multi-tenant machines
       JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
       JAVA_OPTS += " -XX:+CMSIncrementalMode "
@@ -371,7 +392,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
     }
 
     if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
-      JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
+      JAVA_OPTS += " " + env("SPARK_JAVA_OPTS")
     }
 
     // Command for the ApplicationMaster
@@ -381,7 +402,8 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
       javaCommand = Environment.JAVA_HOME.$() + "/bin/java"
     }
 
-    val commands = List[String](javaCommand + 
+    val commands = List[String](
+      javaCommand + 
       " -server " +
       JAVA_OPTS +
       " org.apache.spark.deploy.yarn.ApplicationMaster" +
@@ -393,18 +415,14 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
       " --num-workers " + args.numWorkers +
       " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
       " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
-    logInfo("Command for the ApplicationMaster: " + commands(0))
-    amContainer.setCommands(commands)
 
-    val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource]
-    // Memory for the ApplicationMaster.
-    capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
-    amContainer.setResource(capability)
+    logInfo("Command for starting the Spark ApplicationMaster: " + commands(0))
+    amContainer.setCommands(commands)
 
     // Setup security tokens.
     val dob = new DataOutputBuffer()
     credentials.writeTokenStorageToStream(dob)
-    amContainer.setContainerTokens(ByteBuffer.wrap(dob.getData()))
+    amContainer.setTokens(ByteBuffer.wrap(dob.getData()))
 
     amContainer
   }
@@ -423,7 +441,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
       logInfo("Application report from ASM: \n" +
         "\t application identifier: " + appId.toString() + "\n" +
         "\t appId: " + appId.getId() + "\n" +
-        "\t clientToken: " + report.getClientToken() + "\n" +
+        "\t clientToAMToken: " + report.getClientToAMToken() + "\n" +
         "\t appDiagnostics: " + report.getDiagnostics() + "\n" +
         "\t appMasterHost: " + report.getHost() + "\n" +
         "\t appQueue: " + report.getQueue() + "\n" +
@@ -454,12 +472,13 @@ object Client {
 
   def main(argStrings: Array[String]) {
     // Set an env variable indicating we are running in YARN mode.
-    // Note that anything with SPARK prefix gets propagated to all (remote) processes
+    // Note: any env variable with the SPARK_ prefix gets propagated to all (remote) processes -
+    // see Client#setupLaunchEnv().
     System.setProperty("SPARK_YARN_MODE", "true")
 
     val args = new ClientArguments(argStrings)
 
-    new Client(args).run
+    (new Client(args)).run()
   }
 
   // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps
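Similarly, a hedged sketch of the submission path Client#run() now goes through, written against the standalone YarnClient factory for clarity (the patch's Client extends YarnClientImpl directly). The names, queue and memory figure are illustrative, and the AM container spec is left empty here where the real Client fills in commands, local resources and environment:

    import org.apache.hadoop.yarn.api.records.{ContainerLaunchContext, Resource}
    import org.apache.hadoop.yarn.client.api.YarnClient
    import org.apache.hadoop.yarn.conf.YarnConfiguration
    import org.apache.hadoop.yarn.util.Records

    object YarnSubmitSketch {
      def main(args: Array[String]) {
        val yarnClient = YarnClient.createYarnClient()
        yarnClient.init(new YarnConfiguration())
        yarnClient.start()

        // Ask the ResourceManager (its ApplicationsManager interface) for a new application.
        // The response carries cluster limits, e.g. the maximum container memory that
        // verifyClusterResources() above checks against.
        val newApp = yarnClient.createApplication()
        val appId = newApp.getNewApplicationResponse().getApplicationId()

        // Fill in the submission context: name, queue, AM container spec and AM resource.
        val appContext = newApp.getApplicationSubmissionContext()
        appContext.setApplicationName("spark-sketch")
        appContext.setQueue("default")
        appContext.setAMContainerSpec(Records.newRecord(classOf[ContainerLaunchContext]))
        val amResource = Records.newRecord(classOf[Resource])
        amResource.setMemory(512)
        appContext.setResource(amResource)

        yarnClient.submitApplication(appContext)
        println("Submitted application " + appId)
      }
    }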
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index 852dbd7dabf66e391c4e64584670cdf0a74a853f..6d3c95867e389bf24bdd8c33235e7e53aeb15e90 100644
--- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -17,12 +17,14 @@
 
 package org.apache.spark.deploy.yarn
 
-import org.apache.spark.util.MemoryParam
-import org.apache.spark.util.IntParam
-import collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
 import org.apache.spark.scheduler.{InputFormatInfo, SplitInfo}
+import org.apache.spark.util.IntParam
+import org.apache.spark.util.MemoryParam
+
 
-// TODO: Add code and support for ensuring that yarn resource 'asks' are location aware !
+// TODO: Add code and support for ensuring that YARN resource 'asks' are location-aware!
 class ClientArguments(val args: Array[String]) {
   var addJars: String = null
   var files: String = null
@@ -30,14 +32,16 @@ class ClientArguments(val args: Array[String]) {
   var userJar: String = null
   var userClass: String = null
   var userArgs: Seq[String] = Seq[String]()
-  var workerMemory = 1024
+  var workerMemory = 1024 // MB
   var workerCores = 1
   var numWorkers = 2
   var amQueue = System.getProperty("QUEUE", "default")
-  var amMemory: Int = 512
+  var amMemory: Int = 512 // MB
   var appName: String = "Spark"
   // TODO
   var inputFormatInfo: List[InputFormatInfo] = null
+  // TODO(harvey)
+  var priority = 0
 
   parseArgs(args.toList)
 
@@ -47,8 +51,7 @@ class ClientArguments(val args: Array[String]) {
 
     var args = inputArgs
 
-    while (! args.isEmpty) {
-
+    while (!args.isEmpty) {
       args match {
         case ("--jar") :: value :: tail =>
           userJar = value
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
index 6a90cc51cfbaf71ef5683703dd84fa6d5b872ccb..9f5523c4b97a8811c14a19f9e34c6cd1e32028e7 100644
--- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
@@ -32,10 +32,12 @@ import org.apache.hadoop.security.UserGroupInformation
 import org.apache.hadoop.yarn.api._
 import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
 import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.records.impl.pb.ProtoUtils
 import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.client.api.NMClient
 import org.apache.hadoop.yarn.conf.YarnConfiguration
 import org.apache.hadoop.yarn.ipc.YarnRPC
-import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records, ProtoUtils}
+import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records}
 
 import org.apache.spark.Logging
 
@@ -51,12 +53,14 @@ class WorkerRunnable(
   extends Runnable with Logging {
 
   var rpc: YarnRPC = YarnRPC.create(conf)
-  var cm: ContainerManager = null
+  var nmClient: NMClient = _
   val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
 
   def run = {
     logInfo("Starting Worker Container")
-    cm = connectToCM
+    nmClient = NMClient.createNMClient()
+    nmClient.init(yarnConf)
+    nmClient.start()
     startContainer
   }
 
@@ -66,8 +70,6 @@ class WorkerRunnable(
     val ctx = Records.newRecord(classOf[ContainerLaunchContext])
       .asInstanceOf[ContainerLaunchContext]
 
-    ctx.setContainerId(container.getId())
-    ctx.setResource(container.getResource())
     val localResources = prepareLocalResources
     ctx.setLocalResources(localResources)
 
@@ -111,12 +113,10 @@ class WorkerRunnable(
     }
 */
 
-    ctx.setUser(UserGroupInformation.getCurrentUser().getShortUserName())
-
     val credentials = UserGroupInformation.getCurrentUser().getCredentials()
     val dob = new DataOutputBuffer()
     credentials.writeTokenStorageToStream(dob)
-    ctx.setContainerTokens(ByteBuffer.wrap(dob.getData()))
+    ctx.setTokens(ByteBuffer.wrap(dob.getData()))
 
     var javaCommand = "java"
     val javaHome = System.getenv("JAVA_HOME")
@@ -144,10 +144,7 @@ class WorkerRunnable(
     ctx.setCommands(commands)
 
     // Send the start request to the ContainerManager
-    val startReq = Records.newRecord(classOf[StartContainerRequest])
-    .asInstanceOf[StartContainerRequest]
-    startReq.setContainerLaunchContext(ctx)
-    cm.startContainer(startReq)
+    nmClient.startContainer(container, ctx)
   }
 
   private def setupDistributedCache(
@@ -194,7 +191,7 @@ class WorkerRunnable(
     }
 
     logInfo("Prepared Local resources " + localResources)
-    return localResources
+    localResources
   }
 
   def prepareEnvironment: HashMap[String, String] = {
@@ -206,30 +203,7 @@ class WorkerRunnable(
     Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"))
 
     System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
-    return env
-  }
-
-  def connectToCM: ContainerManager = {
-    val cmHostPortStr = container.getNodeId().getHost() + ":" + container.getNodeId().getPort()
-    val cmAddress = NetUtils.createSocketAddr(cmHostPortStr)
-    logInfo("Connecting to ContainerManager at " + cmHostPortStr)
-
-    // Use doAs and remoteUser here so we can add the container token and not pollute the current
-    // users credentials with all of the individual container tokens
-    val user = UserGroupInformation.createRemoteUser(container.getId().toString())
-    val containerToken = container.getContainerToken()
-    if (containerToken != null) {
-      user.addToken(ProtoUtils.convertFromProtoFormat(containerToken, cmAddress))
-    }
-
-    val proxy = user
-        .doAs(new PrivilegedExceptionAction[ContainerManager] {
-          def run: ContainerManager = {
-            return rpc.getProxy(classOf[ContainerManager],
-                cmAddress, conf).asInstanceOf[ContainerManager]
-          }
-        })
-    proxy
+    env
   }
 
 }
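And a hedged sketch of the NMClient flow that replaces the hand-rolled connectToCM/ContainerManager proxy in WorkerRunnable. The object and method names are illustrative; the Container argument would come from an AMRMClient#allocate() response, and the launch context is kept minimal where the real code adds resources, environment and tokens:

    import org.apache.hadoop.yarn.api.records.{Container, ContainerLaunchContext}
    import org.apache.hadoop.yarn.client.api.NMClient
    import org.apache.hadoop.yarn.conf.YarnConfiguration
    import org.apache.hadoop.yarn.util.Records

    object NmClientSketch {
      // `container` comes from an AMRMClient allocation; `commands` is the process to launch in it.
      def launch(container: Container, conf: YarnConfiguration, commands: java.util.List[String]) {
        val nmClient = NMClient.createNMClient()
        nmClient.init(conf)
        nmClient.start()

        val ctx = Records.newRecord(classOf[ContainerLaunchContext])
        ctx.setCommands(commands)

        // NMClient handles the container/NM token plumbing that the removed connectToCM code
        // (UserGroupInformation.doAs plus ProtoUtils.convertFromProtoFormat) did by hand.
        nmClient.startContainer(container, ctx)
      }
    }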
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
index 6ce470e8cb6718cdfa642fa696df45d4cfb989b9..dba0f7640e67cc88bbe432ec12e218c6103244ff 100644
--- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -32,11 +32,13 @@ import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedu
 import org.apache.spark.util.Utils
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.yarn.api.AMRMProtocol
-import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId}
+import org.apache.hadoop.yarn.api.ApplicationMasterProtocol
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId
 import org.apache.hadoop.yarn.api.records.{Container, ContainerId, ContainerStatus}
 import org.apache.hadoop.yarn.api.records.{Priority, Resource, ResourceRequest}
 import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse}
+import org.apache.hadoop.yarn.client.api.AMRMClient
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
 import org.apache.hadoop.yarn.util.{RackResolver, Records}
 
 
@@ -56,7 +58,7 @@ object AllocationType extends Enumeration ("HOST", "RACK", "ANY") {
 // more info on how we are requesting for containers.
 private[yarn] class YarnAllocationHandler(
     val conf: Configuration,
-    val resourceManager: AMRMProtocol, 
+    val amClient: AMRMClient[ContainerRequest],
     val appAttemptId: ApplicationAttemptId,
     val maxWorkers: Int,
     val workerMemory: Int,
@@ -83,12 +85,17 @@ private[yarn] class YarnAllocationHandler(
   // Containers to be released in next request to RM
   private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean]
 
+  // Number of container requests that have been sent to, but not yet allocated by, the
+  // ResourceManager.
+  private val numPendingAllocate = new AtomicInteger()
   private val numWorkersRunning = new AtomicInteger()
   // Used to generate a unique id per worker
   private val workerIdCounter = new AtomicInteger()
   private val lastResponseId = new AtomicInteger()
   private val numWorkersFailed = new AtomicInteger()
 
+  def getNumPendingAllocate: Int = numPendingAllocate.intValue
+
   def getNumWorkersRunning: Int = numWorkersRunning.intValue
 
   def getNumWorkersFailed: Int = numWorkersFailed.intValue
@@ -97,154 +104,163 @@ private[yarn] class YarnAllocationHandler(
     container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
   }
 
-  def allocateContainers(workersToRequest: Int) {
-    // We need to send the request only once from what I understand ... but for now, not modifying
-    // this much.
+  def releaseContainer(container: Container) {
+    val containerId = container.getId
+    pendingReleaseContainers.put(containerId, true)
+    amClient.releaseAssignedContainer(containerId)
+  }
+
+  def allocateResources() {
+    // We have already set the container request. Poll the ResourceManager for a response.
+    // This doubles as a heartbeat if there are no pending container requests.
+    val progressIndicator = 0.1f
+    val allocateResponse = amClient.allocate(progressIndicator)
 
-    // Keep polling the Resource Manager for containers
-    val amResp = allocateWorkerResources(workersToRequest).getAMResponse
+    val allocatedContainers = allocateResponse.getAllocatedContainers()
+    if (allocatedContainers.size > 0) {
+      var numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * allocatedContainers.size)
 
-    val _allocatedContainers = amResp.getAllocatedContainers()
+      if (numPendingAllocateNow < 0) {
+        numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * numPendingAllocateNow)
+      }
 
-    if (_allocatedContainers.size > 0) {
       logDebug("""
         Allocated containers: %d
         Current worker count: %d
-        Containers to-be-released: %d
-        pendingReleaseContainers: %s
+        Containers released: %s
+        Containers to-be-released: %s
         Cluster resources: %s
         """.format(
           allocatedContainers.size,
           numWorkersRunning.get(),
           releasedContainerList,
           pendingReleaseContainers,
-          amResp.getAvailableResources))
+          allocateResponse.getAvailableResources))
 
       val hostToContainers = new HashMap[String, ArrayBuffer[Container]]()
 
-      // Ignore if not satisfying constraints      {
-      for (container <- _allocatedContainers) {
+      for (container <- allocatedContainers) {
         if (isResourceConstraintSatisfied(container)) {
-          // allocatedContainers += container
-
+          // Add the accepted `container` to the host's list of already accepted,
+          // allocated containers
           val host = container.getNodeId.getHost
-          val containers = hostToContainers.getOrElseUpdate(host, new ArrayBuffer[Container]())
-
-          containers += container
+          val containersForHost = hostToContainers.getOrElseUpdate(host,
+            new ArrayBuffer[Container]())
+          containersForHost += container
+        } else {
+          // Release container, since it doesn't satisfy resource constraints.
+          releaseContainer(container)
         }
-        // Add all ignored containers to released list
-        else releasedContainerList.add(container.getId())
       }
 
-      // Find the appropriate containers to use. Slightly non trivial groupBy ...
+      // Find the appropriate containers to use.
+      // TODO: Clean up this group-by...
       val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
       val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
       val offRackContainers = new HashMap[String, ArrayBuffer[Container]]()
 
-      for (candidateHost <- hostToContainers.keySet)
-      {
+      for (candidateHost <- hostToContainers.keySet) {
         val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0)
         val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost)
 
-        var remainingContainers = hostToContainers.get(candidateHost).getOrElse(null)
-        assert(remainingContainers != null)
+        val remainingContainersOpt = hostToContainers.get(candidateHost)
+        assert(remainingContainersOpt.isDefined)
+        var remainingContainers = remainingContainersOpt.get
 
-        if (requiredHostCount >= remainingContainers.size){
-          // Since we got <= required containers, add all to dataLocalContainers
+        if (requiredHostCount >= remainingContainers.size) {
+          // Since we have <= required containers, add all remaining containers to
+          // `dataLocalContainers`.
           dataLocalContainers.put(candidateHost, remainingContainers)
-          // all consumed
+          // There are no more free containers remaining.
           remainingContainers = null
-        }
-        else if (requiredHostCount > 0) {
+        } else if (requiredHostCount > 0) {
           // Container list has more containers than we need for data locality.
-          // Split into two : data local container count of (remainingContainers.size -
-          // requiredHostCount) and rest as remainingContainer
+          // Split the list into two: one based on the data local container count,
+          // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining
+          // containers.
           val (dataLocal, remaining) = remainingContainers.splitAt(
             remainingContainers.size - requiredHostCount)
           dataLocalContainers.put(candidateHost, dataLocal)
-          // remainingContainers = remaining
 
-          // yarn has nasty habit of allocating a tonne of containers on a host - discourage this :
-          // add remaining to release list. If we have insufficient containers, next allocation 
-          // cycle will reallocate (but wont treat it as data local)
-          for (container <- remaining) releasedContainerList.add(container.getId())
+          // Invariant: remainingContainers == remaining
+
+          // YARN has a nasty habit of allocating a ton of containers on a host - discourage this.
+          // Add each container in `remaining` to list of containers to release. If we have an
+          // insufficient number of containers, then the next allocation cycle will reallocate
+          // (but won't treat it as data local).
+          // TODO(harvey): Rephrase this comment some more.
+          for (container <- remaining) releaseContainer(container)
           remainingContainers = null
         }
 
-        // Now rack local
-        if (remainingContainers != null){
+        // For rack local containers
+        if (remainingContainers != null) {
           val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
-
-          if (rack != null){
+          if (rack != null) {
             val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0)
-            val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) - 
-              rackLocalContainers.get(rack).getOrElse(List()).size
-
+            val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) -
+              rackLocalContainers.getOrElse(rack, List()).size
 
-            if (requiredRackCount >= remainingContainers.size){
-              // Add all to dataLocalContainers
+            if (requiredRackCount >= remainingContainers.size) {
+              // Add all remaining containers to `dataLocalContainers`.
               dataLocalContainers.put(rack, remainingContainers)
-              // All consumed
               remainingContainers = null
-            }
-            else if (requiredRackCount > 0) {
-              // container list has more containers than we need for data locality.
-              // Split into two : data local container count of (remainingContainers.size -
-              // requiredRackCount) and rest as remainingContainer
+            } else if (requiredRackCount > 0) {
+              // Container list has more containers that we need for data locality.
+              // Split the list into two: one based on the data local container count,
+              // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining
+              // containers.
               val (rackLocal, remaining) = remainingContainers.splitAt(
                 remainingContainers.size - requiredRackCount)
               val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack,
                 new ArrayBuffer[Container]())
 
               existingRackLocal ++= rackLocal
+
               remainingContainers = remaining
             }
           }
         }
 
-        // If still not consumed, then it is off rack host - add to that list.
-        if (remainingContainers != null){
+        if (remainingContainers != null) {
+          // Not all containers have been consumed - add them to the list of off-rack containers.
           offRackContainers.put(candidateHost, remainingContainers)
         }
       }
 
-      // Now that we have split the containers into various groups, go through them in order : 
-      // first host local, then rack local and then off rack (everything else).
-      // Note that the list we create below tries to ensure that not all containers end up within a
-      // host if there are sufficiently large number of hosts/containers.
-
-      val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size)
-      allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers)
-      allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers)
-      allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers)
-
-      // Run each of the allocated containers
-      for (container <- allocatedContainers) {
+      // Now that we have split the containers into various groups, go through them in order:
+      // first host-local, then rack-local, and finally off-rack.
+      // Note that the list we create below tries to ensure that not all containers end up within
+      // a host if there is a sufficiently large number of hosts/containers.
+      val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size)
+      allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(dataLocalContainers)
+      allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(rackLocalContainers)
+      allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(offRackContainers)
+
+      // Run each of the allocated containers.
+      for (container <- allocatedContainersToProcess) {
         val numWorkersRunningNow = numWorkersRunning.incrementAndGet()
         val workerHostname = container.getNodeId.getHost
         val containerId = container.getId
 
-        assert(
-          container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD))
+        val workerMemoryOverhead = (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+        assert(container.getResource.getMemory >= workerMemoryOverhead)
 
         if (numWorkersRunningNow > maxWorkers) {
-          logInfo("""Ignoring container %d at host %s, since we already have the required number of
+          logInfo("""Ignoring container %s at host %s, since we already have the required number of
             containers for it.""".format(containerId, workerHostname))
-          releasedContainerList.add(containerId)
-          // reset counter back to old value.
+          releaseContainer(container)
           numWorkersRunning.decrementAndGet()
-        }
-        else {
-          // Deallocate + allocate can result in reusing id's wrongly - so use a different counter
-          // (workerIdCounter)
+        } else {
           val workerId = workerIdCounter.incrementAndGet().toString
           val driverUrl = "akka://spark@%s:%s/user/%s".format(
-            System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
+            System.getProperty("spark.driver.host"),
+            System.getProperty("spark.driver.port"),
             CoarseGrainedSchedulerBackend.ACTOR_NAME)
 
-          logInfo("launching container on " + containerId + " host " + workerHostname)
-          // Just to be safe, simply remove it from pendingReleaseContainers.
-          // Should not be there, but ..
+          logInfo("Launching container %s for on host %s".format(containerId, workerHostname))
+
+          // To be safe, remove the container from `pendingReleaseContainers`.
           pendingReleaseContainers.remove(containerId)
 
           val rack = YarnAllocationHandler.lookupRack(conf, workerHostname)
@@ -254,45 +270,52 @@ private[yarn] class YarnAllocationHandler(
 
             containerSet += containerId
             allocatedContainerToHostMap.put(containerId, workerHostname)
+
             if (rack != null) {
               allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1)
             }
           }
-
-          new Thread(
-            new WorkerRunnable(container, conf, driverUrl, workerId,
-              workerHostname, workerMemory, workerCores)
-          ).start()
+          logInfo("Launching WorkerRunnable. driverUrl: %s,  workerHostname: %s".format(driverUrl, workerHostname))
+          val workerRunnable = new WorkerRunnable(
+            container,
+            conf,
+            driverUrl,
+            workerId,
+            workerHostname,
+            workerMemory,
+            workerCores)
+          new Thread(workerRunnable).start()
         }
       }
       logDebug("""
-        Finished processing %d completed containers.
+        Finished allocating %s containers (from %s originally).
         Current number of workers running: %d,
         releasedContainerList: %s,
         pendingReleaseContainers: %s
         """.format(
-          completedContainers.size,
+          allocatedContainersToProcess.size,
+          allocatedContainers.size,
           numWorkersRunning.get(),
           releasedContainerList,
           pendingReleaseContainers))
     }
 
+    val completedContainers = allocateResponse.getCompletedContainersStatuses()
+    if (completedContainers.size > 0) {
+      logDebug("Completed %d containers".format(completedContainers.size))
 
-    val completedContainers = amResp.getCompletedContainersStatuses()
-    if (completedContainers.size > 0){
-      logDebug("Completed %d containers, to-be-released: %s".format(
-        completedContainers.size, releasedContainerList))
-      for (completedContainer <- completedContainers){
+      for (completedContainer <- completedContainers) {
         val containerId = completedContainer.getContainerId
 
-        // Was this released by us ? If yes, then simply remove from containerSet and move on.
         if (pendingReleaseContainers.containsKey(containerId)) {
+          // YarnAllocationHandler already marked the container for release, so remove it from
+          // `pendingReleaseContainers`.
           pendingReleaseContainers.remove(containerId)
-        }
-        else {
-          // Simply decrement count - next iteration of ReporterThread will take care of allocating.
+        } else {
+          // Decrement the number of workers running. The next iteration of the ApplicationMaster's
+          // reporting thread will take care of allocating.
           numWorkersRunning.decrementAndGet()
-          logInfo("Completed container %d (state: %s, http address: %s, exit status: %s)".format(
+          logInfo("Completed container %s (state: %s, exit status: %s)".format(
             containerId,
             completedContainer.getState,
             completedContainer.getExitStatus()))
@@ -307,24 +330,32 @@ private[yarn] class YarnAllocationHandler(
 
         allocatedHostToContainersMap.synchronized {
           if (allocatedContainerToHostMap.containsKey(containerId)) {
-            val host = allocatedContainerToHostMap.get(containerId).getOrElse(null)
-            assert (host != null)
-
-            val containerSet = allocatedHostToContainersMap.get(host).getOrElse(null)
-            assert (containerSet != null)
-
-            containerSet -= containerId
-            if (containerSet.isEmpty) allocatedHostToContainersMap.remove(host)
-            else allocatedHostToContainersMap.update(host, containerSet)
+            val hostOpt = allocatedContainerToHostMap.get(containerId)
+            assert(hostOpt.isDefined)
+            val host = hostOpt.get
+
+            val containerSetOpt = allocatedHostToContainersMap.get(host)
+            assert(containerSetOpt.isDefined)
+            val containerSet = containerSetOpt.get
+
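+            // Remove the container from the per-host bookkeeping, dropping the host entry once
+            // its last container completes.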
+            containerSet.remove(containerId)
+            if (containerSet.isEmpty) {
+              allocatedHostToContainersMap.remove(host)
+            } else {
+              allocatedHostToContainersMap.update(host, containerSet)
+            }
 
-            allocatedContainerToHostMap -= containerId
+            allocatedContainerToHostMap.remove(containerId)
 
-            // Doing this within locked context, sigh ... move to outside ?
+            // TODO: Move this part outside the synchronized block?
             val rack = YarnAllocationHandler.lookupRack(conf, host)
             if (rack != null) {
               val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1
-              if (rackCount > 0) allocatedRackCount.put(rack, rackCount)
-              else allocatedRackCount.remove(rack)
+              if (rackCount > 0) {
+                allocatedRackCount.put(rack, rackCount)
+              } else {
+                allocatedRackCount.remove(rack)
+              }
             }
           }
         }
@@ -342,32 +373,34 @@ private[yarn] class YarnAllocationHandler(
     }
   }
 
-  def createRackResourceRequests(hostContainers: List[ResourceRequest]): List[ResourceRequest] = {
-    // First generate modified racks and new set of hosts under it : then issue requests
+  def createRackResourceRequests(
+      hostContainers: ArrayBuffer[ContainerRequest]
+    ): ArrayBuffer[ContainerRequest] = {
+    // Group the host-level requests by rack before issuing rack-level requests.
     val rackToCounts = new HashMap[String, Int]()
 
-    // Within this lock - used to read/write to the rack related maps too.
     for (container <- hostContainers) {
-      val candidateHost = container.getHostName
-      val candidateNumContainers = container.getNumContainers
+      val candidateHost = container.getNodes.last
       assert(YarnAllocationHandler.ANY_HOST != candidateHost)
 
       val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
       if (rack != null) {
         var count = rackToCounts.getOrElse(rack, 0)
-        count += candidateNumContainers
+        count += 1
         rackToCounts.put(rack, count)
       }
     }
 
-    val requestedContainers: ArrayBuffer[ResourceRequest] = 
-      new ArrayBuffer[ResourceRequest](rackToCounts.size)
-    for ((rack, count) <- rackToCounts){
-      requestedContainers += 
-        createResourceRequest(AllocationType.RACK, rack, count, YarnAllocationHandler.PRIORITY)
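+    // Issue one rack-local request per container counted on each rack.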
+    val requestedContainers = new ArrayBuffer[ContainerRequest](rackToCounts.size)
+    for ((rack, count) <- rackToCounts) {
+      requestedContainers ++= createResourceRequests(
+        AllocationType.RACK,
+        rack,
+        count,
+        YarnAllocationHandler.PRIORITY)
     }
 
-    requestedContainers.toList
+    requestedContainers
   }
 
   def allocatedContainersOnHost(host: String): Int = {
@@ -386,147 +419,128 @@ private[yarn] class YarnAllocationHandler(
     retval
   }
 
-  private def allocateWorkerResources(numWorkers: Int): AllocateResponse = {
-
-    var resourceRequests: List[ResourceRequest] = null
-
-      // default.
-    if (numWorkers <= 0 || preferredHostToCount.isEmpty) {
-      logDebug("numWorkers: " + numWorkers + ", host preferences: " + preferredHostToCount.isEmpty)
-      resourceRequests = List(
-        createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY))
-    }
-    else {
-      // request for all hosts in preferred nodes and for numWorkers - 
-      // candidates.size, request by default allocation policy.
-      val hostContainerRequests: ArrayBuffer[ResourceRequest] = 
-        new ArrayBuffer[ResourceRequest](preferredHostToCount.size)
-      for ((candidateHost, candidateCount) <- preferredHostToCount) {
-        val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost)
-
-        if (requiredCount > 0) {
-          hostContainerRequests += createResourceRequest(
-            AllocationType.HOST,
-            candidateHost,
-            requiredCount,
-            YarnAllocationHandler.PRIORITY)
+  def addResourceRequests(numWorkers: Int) {
+    val containerRequests: List[ContainerRequest] =
+      if (numWorkers <= 0 || preferredHostToCount.isEmpty) {
+        logDebug("numWorkers: " + numWorkers + ", host preferences: " +
+          preferredHostToCount.isEmpty)
+        createResourceRequests(
+          AllocationType.ANY,
+          resource = null,
+          numWorkers,
+          YarnAllocationHandler.PRIORITY).toList
+      } else {
+        // Request containers on all preferred hosts; any remaining workers
+        // (numWorkers - candidates.size) are requested via the default allocation policy.
+        val hostContainerRequests = new ArrayBuffer[ContainerRequest](preferredHostToCount.size)
+        for ((candidateHost, candidateCount) <- preferredHostToCount) {
+          val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost)
+
+          if (requiredCount > 0) {
+            hostContainerRequests ++= createResourceRequests(
+              AllocationType.HOST,
+              candidateHost,
+              requiredCount,
+              YarnAllocationHandler.PRIORITY)
+          }
         }
+        val rackContainerRequests: List[ContainerRequest] = createRackResourceRequests(
+          hostContainerRequests).toList
+
+        val anyContainerRequests = createResourceRequests(
+          AllocationType.ANY,
+          resource = null,
+          numWorkers,
+          YarnAllocationHandler.PRIORITY)
+
+        val containerRequestBuffer = new ArrayBuffer[ContainerRequest](
+          hostContainerRequests.size + rackContainerRequests.size + anyContainerRequests.size)
+
+        containerRequestBuffer ++= hostContainerRequests
+        containerRequestBuffer ++= rackContainerRequests
+        containerRequestBuffer ++= anyContainerRequests
+        containerRequestBuffer.toList
       }
-      val rackContainerRequests: List[ResourceRequest] = createRackResourceRequests(
-        hostContainerRequests.toList)
 
-      val anyContainerRequests: ResourceRequest = createResourceRequest(
-        AllocationType.ANY,
-        resource = null,
-        numWorkers,
-        YarnAllocationHandler.PRIORITY)
-
-      val containerRequests: ArrayBuffer[ResourceRequest] = new ArrayBuffer[ResourceRequest](
-        hostContainerRequests.size() + rackContainerRequests.size() + 1)
-
-      containerRequests ++= hostContainerRequests
-      containerRequests ++= rackContainerRequests
-      containerRequests += anyContainerRequests
-
-      resourceRequests = containerRequests.toList
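+    // Register every request with the AMRMClient; the ResourceManager grants containers on a
+    // subsequent allocate heartbeat.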
+    for (request <- containerRequests) {
+      amClient.addContainerRequest(request)
     }
 
-    val req = Records.newRecord(classOf[AllocateRequest])
-    req.setResponseId(lastResponseId.incrementAndGet)
-    req.setApplicationAttemptId(appAttemptId)
-
-    req.addAllAsks(resourceRequests)
-
-    val releasedContainerList = createReleasedContainerList()
-    req.addAllReleases(releasedContainerList)
-
     if (numWorkers > 0) {
-      logInfo("Allocating %d worker containers with %d of memory each.").format(numWorkers,
-        workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
-    }
-    else {
-      logDebug("Empty allocation req ..  release : " + releasedContainerList)
+      numPendingAllocate.addAndGet(numWorkers)
+      logInfo("Will Allocate %d worker containers, each with %d memory".format(
+        numWorkers,
+        (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)))
+    } else {
+      logDebug("Empty allocation request ...")
     }
 
-    for (request <- resourceRequests) {
-      logInfo("ResourceRequest (host : %s, num containers: %d, priority = %d , capability : %s)").
-        format(
-          request.getHostName,
-          request.getNumContainers,
-          request.getPriority,
-          request.getCapability)
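+    // Log each outstanding container request, using "Any" when no specific nodes were requested.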
+    for (request <- containerRequests) {
+      val nodes = request.getNodes
+      val hostStr = if (nodes == null || nodes.isEmpty) {
+        "Any"
+      } else {
+        nodes.last
+      }
+      logInfo("Container request (host: %s, priority: %s, capability: %s".format(
+        hostStr,
+        request.getPriority().getPriority,
+        request.getCapability))
     }
-    resourceManager.allocate(req)
   }
 
+  private def createResourceRequests(
+      requestType: AllocationType.AllocationType,
+      resource: String,
+      numWorkers: Int,
+      priority: Int
+    ): ArrayBuffer[ContainerRequest] = {
 
-  private def createResourceRequest(
-    requestType: AllocationType.AllocationType, 
-    resource:String,
-    numWorkers: Int,
-    priority: Int): ResourceRequest = {
-
-    // If hostname specified, we need atleast two requests - node local and rack local.
-    // There must be a third request - which is ANY : that will be specially handled.
+    // If hostname is specified, then we need at least two requests - node local and rack local.
+    // There must be a third request, which is ANY. That will be specially handled.
     requestType match {
       case AllocationType.HOST => {
         assert(YarnAllocationHandler.ANY_HOST != resource)
         val hostname = resource
-        val nodeLocal = createResourceRequestImpl(hostname, numWorkers, priority)
+        val nodeLocal = constructContainerRequests(
+          Array(hostname),
+          racks = null,
+          numWorkers,
+          priority)
 
-        // Add to host->rack mapping
+        // Add `hostname` to the global (singleton) host->rack mapping in YarnAllocationHandler.
         YarnAllocationHandler.populateRackInfo(conf, hostname)
-
         nodeLocal
       }
       case AllocationType.RACK => {
         val rack = resource
-        createResourceRequestImpl(rack, numWorkers, priority)
+        constructContainerRequests(hosts = null, Array(rack), numWorkers, priority)
       }
-      case AllocationType.ANY => createResourceRequestImpl(
-        YarnAllocationHandler.ANY_HOST, numWorkers, priority)
+      case AllocationType.ANY => constructContainerRequests(
+        hosts = null, racks = null, numWorkers, priority)
       case _ => throw new IllegalArgumentException(
         "Unexpected/unsupported request type: " + requestType)
     }
   }
 
-  private def createResourceRequestImpl(
-    hostname:String,
-    numWorkers: Int,
-    priority: Int): ResourceRequest = {
-
-    val rsrcRequest = Records.newRecord(classOf[ResourceRequest])
-    val memCapability = Records.newRecord(classOf[Resource])
-    // There probably is some overhead here, let's reserve a bit more memory.
-    memCapability.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
-    rsrcRequest.setCapability(memCapability)
+  private def constructContainerRequests(
+      hosts: Array[String],
+      racks: Array[String],
+      numWorkers: Int,
+      priority: Int
+    ): ArrayBuffer[ContainerRequest] = {
 
-    val pri = Records.newRecord(classOf[Priority])
-    pri.setPriority(priority)
-    rsrcRequest.setPriority(pri)
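+    // Request the worker memory plus MEMORY_OVERHEAD - there is probably some per-container
+    // overhead, so reserve a bit more memory.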
+    val memoryResource = Records.newRecord(classOf[Resource])
+    memoryResource.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
 
-    rsrcRequest.setHostName(hostname)
-
-    rsrcRequest.setNumContainers(java.lang.Math.max(numWorkers, 0))
-    rsrcRequest
-  }
+    val prioritySetting = Records.newRecord(classOf[Priority])
+    prioritySetting.setPriority(priority)
 
-  def createReleasedContainerList(): ArrayBuffer[ContainerId] = {
-
-    val retval = new ArrayBuffer[ContainerId](1)
-    // Iterator on COW list ...
-    for (container <- releasedContainerList.iterator()){
-      retval += container
-    }
-    // Remove from the original list.
-    if (! retval.isEmpty) {
-      releasedContainerList.removeAll(retval)
-      for (v <- retval) pendingReleaseContainers.put(v, true)
-      logInfo("Releasing " + retval.size + " containers. pendingReleaseContainers : " + 
-        pendingReleaseContainers)
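+    // The AMRMClient API takes one ContainerRequest per container, so build numWorkers
+    // identical requests.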
+    val requests = new ArrayBuffer[ContainerRequest]()
+    for (i <- 0 until numWorkers) {
+      requests += new ContainerRequest(memoryResource, hosts, racks, prioritySetting)
     }
-
-    retval
+    requests
   }
 }
 
@@ -537,26 +551,25 @@ object YarnAllocationHandler {
   // request types (like map/reduce in hadoop for example)
   val PRIORITY = 1
 
-  // Additional memory overhead - in mb
+  // Additional memory overhead, in MB.
   val MEMORY_OVERHEAD = 384
 
-  // Host to rack map - saved from allocation requests
-  // We are expecting this not to change.
-  // Note that it is possible for this to change : and RM will indicate that to us via update 
-  // response to allocate. But we are punting on handling that for now.
+  // Host to rack map - saved from allocation requests. We expect this not to change.
+  // Note that it is possible for this to change: the ResourceManager would indicate that to us
+  // via an update response to allocate. But we are punting on handling that for now.
   private val hostToRack = new ConcurrentHashMap[String, String]()
   private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]()
 
 
   def newAllocator(
-    conf: Configuration,
-    resourceManager: AMRMProtocol,
-    appAttemptId: ApplicationAttemptId,
-    args: ApplicationMasterArguments): YarnAllocationHandler = {
-
+      conf: Configuration,
+      amClient: AMRMClient[ContainerRequest],
+      appAttemptId: ApplicationAttemptId,
+      args: ApplicationMasterArguments
+    ): YarnAllocationHandler = {
     new YarnAllocationHandler(
       conf,
-      resourceManager,
+      amClient,
       appAttemptId,
       args.numWorkers, 
       args.workerMemory,
@@ -566,39 +579,38 @@ object YarnAllocationHandler {
   }
 
   def newAllocator(
-    conf: Configuration,
-    resourceManager: AMRMProtocol,
-    appAttemptId: ApplicationAttemptId,
-    args: ApplicationMasterArguments,
-    map: collection.Map[String,
-    collection.Set[SplitInfo]]): YarnAllocationHandler = {
-
-    val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
+      conf: Configuration,
+      amClient: AMRMClient[ContainerRequest],
+      appAttemptId: ApplicationAttemptId,
+      args: ApplicationMasterArguments,
+      map: collection.Map[String, collection.Set[SplitInfo]]
+    ): YarnAllocationHandler = {
+    val (hostToSplitCount, rackToSplitCount) = generateNodeToWeight(conf, map)
     new YarnAllocationHandler(
       conf,
-      resourceManager,
+      amClient,
       appAttemptId,
       args.numWorkers, 
       args.workerMemory,
       args.workerCores,
-      hostToCount,
-      rackToCount)
+      hostToSplitCount,
+      rackToSplitCount)
   }
 
   def newAllocator(
-    conf: Configuration,
-    resourceManager: AMRMProtocol,
-    appAttemptId: ApplicationAttemptId,
-    maxWorkers: Int,
-    workerMemory: Int,
-    workerCores: Int,
-    map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
-
+      conf: Configuration,
+      amClient: AMRMClient[ContainerRequest],
+      appAttemptId: ApplicationAttemptId,
+      maxWorkers: Int,
+      workerMemory: Int,
+      workerCores: Int,
+      map: collection.Map[String, collection.Set[SplitInfo]]
+    ): YarnAllocationHandler = {
     val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
-
     new YarnAllocationHandler(
       conf,
-      resourceManager,
+      amClient,
       appAttemptId,
       maxWorkers,
       workerMemory,
@@ -609,12 +621,13 @@ object YarnAllocationHandler {
 
   // A simple method to copy the split info map.
   private def generateNodeToWeight(
-    conf: Configuration,
-    input: collection.Map[String, collection.Set[SplitInfo]]) :
-  // host to count, rack to count
-  (Map[String, Int], Map[String, Int]) = {
+      conf: Configuration,
+      input: collection.Map[String, collection.Set[SplitInfo]]
+    ): (Map[String, Int], Map[String, Int]) = {
 
-    if (input == null) return (Map[String, Int](), Map[String, Int]())
+    if (input == null) {
+      return (Map[String, Int](), Map[String, Int]())
+    }
 
     val hostToCount = new HashMap[String, Int]
     val rackToCount = new HashMap[String, Int]
@@ -634,24 +647,25 @@ object YarnAllocationHandler {
   }
 
   def lookupRack(conf: Configuration, host: String): String = {
-    if (!hostToRack.contains(host)) populateRackInfo(conf, host)
+    if (!hostToRack.contains(host)) {
+      populateRackInfo(conf, host)
+    }
     hostToRack.get(host)
   }
 
   def fetchCachedHostsForRack(rack: String): Option[Set[String]] = {
-    val set = rackToHostSet.get(rack)
-    if (set == null) return None
-
-    // No better way to get a Set[String] from JSet ?
-    val convertedSet: collection.mutable.Set[String] = set
-    Some(convertedSet.toSet)
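+    // Wrap the (possibly null) JSet in an Option and convert it to an immutable Scala Set.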
+    Option(rackToHostSet.get(rack)).map { set =>
+      val convertedSet: collection.mutable.Set[String] = set
+      // TODO: Find a cleaner way to convert a JSet to an immutable Set[String].
+      convertedSet.toSet
+    }
   }
 
   def populateRackInfo(conf: Configuration, hostname: String) {
     Utils.checkHost(hostname)
 
     if (!hostToRack.containsKey(hostname)) {
-      // If there are repeated failures to resolve, all to an ignore list ?
+      // TODO: If there are repeated failures to resolve, add the host to an ignore list.
       val rackInfo = RackResolver.resolve(conf, hostname)
       if (rackInfo != null && rackInfo.getNetworkLocation != null) {
         val rack = rackInfo.getNetworkLocation
@@ -662,7 +676,7 @@ object YarnAllocationHandler {
         }
         rackToHostSet.get(rack).add(hostname)
 
-        // TODO(harvey): Figure out this comment...
+        // TODO(harvey): Figure out what this comment means...
         // Since RackResolver caches, we are disabling this for now ...
       } /* else {
         // right ? Else we will keep calling rack resolver in case we cant resolve rack info ...