diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 48792a958130c42b51fd3c138d11b5793a85c86b..2a8220ff400905420543ce98ffe278dd3e6cb2e6 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -20,13 +20,15 @@ package org.apache.spark
 import java.util.Collections
 import java.util.concurrent.TimeUnit
 
+import scala.concurrent._
+import scala.concurrent.duration.Duration
+import scala.util.Try
+
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.java.JavaFutureAction
 import org.apache.spark.rdd.RDD
-import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter}
+import org.apache.spark.scheduler.JobWaiter
 
-import scala.concurrent._
-import scala.concurrent.duration.Duration
-import scala.util.{Failure, Try}
 
 /**
  * A future for the result of an action to support cancellation. This is an extension of the
@@ -105,6 +107,7 @@ trait FutureAction[T] extends Future[T] {
  * A [[FutureAction]] holding the result of an action that triggers a single job. Examples include
  * count, collect, reduce.
  */
+@DeveloperApi
 class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
   extends FutureAction[T] {
 
@@ -116,142 +119,96 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
   }
 
   override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
-    if (!atMost.isFinite()) {
-      awaitResult()
-    } else jobWaiter.synchronized {
-      val finishTime = System.currentTimeMillis() + atMost.toMillis
-      while (!isCompleted) {
-        val time = System.currentTimeMillis()
-        if (time >= finishTime) {
-          throw new TimeoutException
-        } else {
-          jobWaiter.wait(finishTime - time)
-        }
-      }
-    }
+    jobWaiter.completionFuture.ready(atMost)
     this
   }
 
   @throws(classOf[Exception])
   override def result(atMost: Duration)(implicit permit: CanAwait): T = {
-    ready(atMost)(permit)
-    awaitResult() match {
-      case scala.util.Success(res) => res
-      case scala.util.Failure(e) => throw e
-    }
+    jobWaiter.completionFuture.ready(atMost)
+    assert(value.isDefined, "Future has not completed properly")
+    value.get.get
   }
 
   override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) {
-    executor.execute(new Runnable {
-      override def run() {
-        func(awaitResult())
-      }
-    })
+    jobWaiter.completionFuture.onComplete { _ => func(value.get) }
   }
 
   override def isCompleted: Boolean = jobWaiter.jobFinished
 
   override def isCancelled: Boolean = _cancelled
 
-  override def value: Option[Try[T]] = {
-    if (jobWaiter.jobFinished) {
-      Some(awaitResult())
-    } else {
-      None
-    }
-  }
-
-  private def awaitResult(): Try[T] = {
-    jobWaiter.awaitResult() match {
-      case JobSucceeded => scala.util.Success(resultFunc)
-      case JobFailed(e: Exception) => scala.util.Failure(e)
-    }
-  }
+  override def value: Option[Try[T]] =
+    jobWaiter.completionFuture.value.map { res => res.map(_ => resultFunc) }
 
   def jobIds: Seq[Int] = Seq(jobWaiter.jobId)
 }
 
 
+/**
+  * Handle via which a "run" function passed to a [[ComplexFutureAction]]
+  * can submit jobs for execution.
+  */
+@DeveloperApi
+trait JobSubmitter {
+  /**
+    * Submit a job for execution and return a FutureAction holding the result.
+    * This is a wrapper around the same functionality provided by SparkContext
+    * to enable cancellation.
+    */
+  def submitJob[T, U, R](
+    rdd: RDD[T],
+    processPartition: Iterator[T] => U,
+    partitions: Seq[Int],
+    resultHandler: (Int, U) => Unit,
+    resultFunc: => R): FutureAction[R]
+}
+
+
 /**
  * A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take,
- * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
- * action thread if it is being blocked by a job.
+ * takeSample. Cancellation works by setting the cancelled flag to true and cancelling any pending
+ * jobs.
  */
-class ComplexFutureAction[T] extends FutureAction[T] {
+@DeveloperApi
+class ComplexFutureAction[T](run: JobSubmitter => Future[T])
+  extends FutureAction[T] { self =>
 
-  // Pointer to the thread that is executing the action. It is set when the action is run.
-  @volatile private var thread: Thread = _
+  @volatile private var _cancelled = false
 
-  // A flag indicating whether the future has been cancelled. This is used in case the future
-  // is cancelled before the action was even run (and thus we have no thread to interrupt).
-  @volatile private var _cancelled: Boolean = false
-
-  @volatile private var jobs: Seq[Int] = Nil
+  @volatile private var subActions: List[FutureAction[_]] = Nil
 
   // A promise used to signal the future.
-  private val p = promise[T]()
+  private val p = Promise[T]().tryCompleteWith(run(jobSubmitter))
 
-  override def cancel(): Unit = this.synchronized {
+  override def cancel(): Unit = synchronized {
     _cancelled = true
-    if (thread != null) {
-      thread.interrupt()
-    }
-  }
-
-  /**
-   * Executes some action enclosed in the closure. To properly enable cancellation, the closure
-   * should use runJob implementation in this promise. See takeAsync for example.
-   */
-  def run(func: => T)(implicit executor: ExecutionContext): this.type = {
-    scala.concurrent.future {
-      thread = Thread.currentThread
-      try {
-        p.success(func)
-      } catch {
-        case e: Exception => p.failure(e)
-      } finally {
-        // This lock guarantees when calling `thread.interrupt()` in `cancel`,
-        // thread won't be set to null.
-        ComplexFutureAction.this.synchronized {
-          thread = null
-        }
-      }
-    }
-    this
+    p.tryFailure(new SparkException("Action has been cancelled"))
+    subActions.foreach(_.cancel())
   }
 
-  /**
-   * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext
-   * to enable cancellation.
-   */
-  def runJob[T, U, R](
+  private def jobSubmitter = new JobSubmitter {
+    def submitJob[T, U, R](
       rdd: RDD[T],
       processPartition: Iterator[T] => U,
       partitions: Seq[Int],
       resultHandler: (Int, U) => Unit,
-      resultFunc: => R) {
-    // If the action hasn't been cancelled yet, submit the job. The check and the submitJob
-    // command need to be in an atomic block.
-    val job = this.synchronized {
+      resultFunc: => R): FutureAction[R] = self.synchronized {
+      // If the action hasn't been cancelled yet, submit the job. The check and the submitJob
+      // command need to be in an atomic block.
       if (!isCancelled) {
-        rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc)
+        val job = rdd.context.submitJob(
+          rdd,
+          processPartition,
+          partitions,
+          resultHandler,
+          resultFunc)
+        subActions = job :: subActions
+        job
       } else {
         throw new SparkException("Action has been cancelled")
       }
     }
-
-    this.jobs = jobs ++ job.jobIds
-
-    // Wait for the job to complete. If the action is cancelled (with an interrupt),
-    // cancel the job and stop the execution. This is not in a synchronized block because
-    // Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
-    try {
-      Await.ready(job, Duration.Inf)
-    } catch {
-      case e: InterruptedException =>
-        job.cancel()
-        throw new SparkException("Action has been cancelled")
-    }
   }
 
   override def isCancelled: Boolean = _cancelled
@@ -276,10 +233,11 @@ class ComplexFutureAction[T] extends FutureAction[T] {
 
   override def value: Option[Try[T]] = p.future.value
 
-  def jobIds: Seq[Int] = jobs
+  def jobIds: Seq[Int] = subActions.flatMap(_.jobIds)
 
 }
 
+
 private[spark]
 class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S => T)
   extends JavaFutureAction[T] {
@@ -303,7 +261,7 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S
     Await.ready(futureAction, timeout)
     futureAction.value.get match {
       case scala.util.Success(value) => converter(value)
-      case Failure(exception) =>
+      case scala.util.Failure(exception) =>
         if (isCancelled) {
           throw new CancellationException("Job cancelled").initCause(exception)
         } else {
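
Note: the following is a minimal sketch, not part of this patch, of how client code might drive the new ComplexFutureAction: the run function receives a JobSubmitter, submits a job, and composes the resulting future instead of blocking a thread. The sumAsync helper, the global execution context, and the input RDD are illustrative assumptions.

    import scala.concurrent.Future
    import scala.concurrent.ExecutionContext.Implicits.global

    import org.apache.spark.{ComplexFutureAction, JobSubmitter}
    import org.apache.spark.rdd.RDD

    // Hypothetical helper: sum an RDD[Int] without ever blocking a thread on the job.
    def sumAsync(rdd: RDD[Int]): ComplexFutureAction[Long] = {
      val run: JobSubmitter => Future[Long] = { submitter =>
        val partials = new Array[Long](rdd.partitions.length)
        val job = submitter.submitJob(
          rdd,
          (it: Iterator[Int]) => it.map(_.toLong).sum,       // per-partition work
          rdd.partitions.indices,                            // scan every partition
          (index: Int, partial: Long) => partials(index) = partial,
          ())                                                // resultFunc unused here
        // The continuation runs as a callback once the job's future completes.
        job.map(_ => partials.sum)
      }
      new ComplexFutureAction[Long](run)
    }

Calling cancel() on the returned action fails its promise and cancels the submitted job, matching the new cancellation path above.
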
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index d5e853613b05bfec83cf9608a40d6e62391fb7e0..14f541f937b4c59a26db191fd7c66e8f633a34f4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -19,13 +19,12 @@ package org.apache.spark.rdd
 
 import java.util.concurrent.atomic.AtomicLong
 
-import org.apache.spark.util.ThreadUtils
-
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext
+import scala.concurrent.{Future, ExecutionContext}
 import scala.reflect.ClassTag
 
-import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
+import org.apache.spark.{JobSubmitter, ComplexFutureAction, FutureAction, Logging}
+import org.apache.spark.util.ThreadUtils
 
 /**
  * A set of asynchronous RDD actions available through an implicit conversion.
@@ -65,17 +64,23 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
    * Returns a future for retrieving the first num elements of the RDD.
    */
   def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope {
-    val f = new ComplexFutureAction[Seq[T]]
     val callSite = self.context.getCallSite
-
-    f.run {
-      // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which
-      // is a cached thread pool.
-      val results = new ArrayBuffer[T](num)
-      val totalParts = self.partitions.length
-      var partsScanned = 0
-      self.context.setCallSite(callSite)
-      while (results.size < num && partsScanned < totalParts) {
+    val localProperties = self.context.getLocalProperties
+    // Cached thread pool to handle aggregation of subtasks.
+    implicit val executionContext = AsyncRDDActions.futureExecutionContext
+    val results = new ArrayBuffer[T](num)
+    val totalParts = self.partitions.length
+
+    /*
+      Recursively triggers jobs to scan partitions until either the requested
+      number of elements is retrieved or the partitions to scan are exhausted.
+      This implementation is non-blocking: it asynchronously handles the
+      results of each job and triggers the next job using callbacks on futures.
+     */
+    def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter): Future[Seq[T]] =
+      if (results.size >= num || partsScanned >= totalParts) {
+        Future.successful(results.toSeq)
+      } else {
         // The number of partitions to try in this iteration. It is ok for this number to be
         // greater than totalParts because we actually cap it at totalParts in runJob.
         var numPartsToTry = 1
@@ -97,19 +102,20 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
         val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
 
         val buf = new Array[Array[T]](p.size)
-        f.runJob(self,
+        self.context.setCallSite(callSite)
+        self.context.setLocalProperties(localProperties)
+        val job = jobSubmitter.submitJob(self,
           (it: Iterator[T]) => it.take(left).toArray,
           p,
           (index: Int, data: Array[T]) => buf(index) = data,
           Unit)
-
-        buf.foreach(results ++= _.take(num - results.size))
-        partsScanned += numPartsToTry
+        job.flatMap { _ =>
+          buf.foreach(results ++= _.take(num - results.size))
+          continue(partsScanned + numPartsToTry)
+        }
       }
-      results.toSeq
-    }(AsyncRDDActions.futureExecutionContext)
 
-    f
+    new ComplexFutureAction[Seq[T]](continue(0)(_))
   }
 
   /**
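
For reference, a Spark-free sketch of the continuation pattern used by the rewritten takeAsync: each asynchronous step stands in for one submitted job, and continue chains the next step with flatMap rather than looping and blocking. All names (fetchChunk, takeNonBlocking) are illustrative.

    import scala.collection.mutable.ArrayBuffer
    import scala.concurrent.Future
    import scala.concurrent.ExecutionContext.Implicits.global

    // Stand-in for one submitted job that produces a chunk of results.
    def fetchChunk(start: Int, count: Int): Future[Seq[Int]] =
      Future(start until (start + count))

    def takeNonBlocking(num: Int, total: Int): Future[Seq[Int]] = {
      val results = new ArrayBuffer[Int](num)
      def continue(scanned: Int): Future[Seq[Int]] =
        if (results.size >= num || scanned >= total) {
          Future.successful(results.toSeq)
        } else {
          val batch = math.min(4, total - scanned)
          fetchChunk(scanned, batch).flatMap { chunk =>
            results ++= chunk.take(num - results.size)
            continue(scanned + batch)   // the next step runs inside a callback
          }
        }
      continue(0)
    }
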
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 5582720bbcff2d06eb7d10de4e49a3ee2016081d..8d0e0c8624a55f2cd7b81dcc29e0e4ac2daa6089 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.Map
 import scala.collection.mutable.{HashMap, HashSet, Stack}
+import scala.concurrent.Await
 import scala.concurrent.duration._
 import scala.language.existentials
 import scala.language.postfixOps
@@ -610,11 +611,12 @@ class DAGScheduler(
       properties: Properties): Unit = {
     val start = System.nanoTime
     val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
-    waiter.awaitResult() match {
-      case JobSucceeded =>
+    Await.ready(waiter.completionFuture, atMost = Duration.Inf)
+    waiter.completionFuture.value.get match {
+      case scala.util.Success(_) =>
         logInfo("Job %d finished: %s, took %f s".format
           (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
-      case JobFailed(exception: Exception) =>
+      case scala.util.Failure(exception) =>
         logInfo("Job %d failed: %s, took %f s".format
           (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
         // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
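
A small illustration, assuming an arbitrary Future[Unit], of the Await.ready pattern adopted here: ready blocks until completion without rethrowing, so the caller can pattern match on the outcome and wrap the failure before throwing, which is what runJob does to attach the user stack trace.

    import scala.concurrent.{Await, Future}
    import scala.concurrent.duration.Duration
    import scala.util.{Failure, Success}

    def waitAndReport(completionFuture: Future[Unit]): Unit = {
      // Unlike Await.result, Await.ready does not rethrow the failure itself.
      Await.ready(completionFuture, Duration.Inf)
      completionFuture.value.get match {
        case Success(_) => println("job finished")
        case Failure(exception) => throw new RuntimeException("job failed", exception)
      }
    }
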
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
index 382b09422a4a0a19efe71a58281c1e0c076661da..4326135186a73fd10b48d2922bc0e4eb8b0f6713 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.scheduler
 
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.concurrent.{Future, Promise}
+
 /**
  * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
  * results to the given handler function.
@@ -28,17 +32,15 @@ private[spark] class JobWaiter[T](
     resultHandler: (Int, T) => Unit)
   extends JobListener {
 
-  private var finishedTasks = 0
-
-  // Is the job as a whole finished (succeeded or failed)?
-  @volatile
-  private var _jobFinished = totalTasks == 0
-
-  def jobFinished: Boolean = _jobFinished
-
+  private val finishedTasks = new AtomicInteger(0)
   // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero
   // partition RDDs), we set the jobResult directly to JobSucceeded.
-  private var jobResult: JobResult = if (jobFinished) JobSucceeded else null
+  private val jobPromise: Promise[Unit] =
+    if (totalTasks == 0) Promise.successful(()) else Promise()
+
+  def jobFinished: Boolean = jobPromise.isCompleted
+
+  def completionFuture: Future[Unit] = jobPromise.future
 
   /**
    * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled
@@ -49,29 +51,17 @@ private[spark] class JobWaiter[T](
     dagScheduler.cancelJob(jobId)
   }
 
-  override def taskSucceeded(index: Int, result: Any): Unit = synchronized {
-    if (_jobFinished) {
-      throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
+  override def taskSucceeded(index: Int, result: Any): Unit = {
+    // resultHandler call must be synchronized in case resultHandler itself is not thread safe.
+    synchronized {
+      resultHandler(index, result.asInstanceOf[T])
     }
-    resultHandler(index, result.asInstanceOf[T])
-    finishedTasks += 1
-    if (finishedTasks == totalTasks) {
-      _jobFinished = true
-      jobResult = JobSucceeded
-      this.notifyAll()
+    if (finishedTasks.incrementAndGet() == totalTasks) {
+      jobPromise.success(())
     }
   }
 
-  override def jobFailed(exception: Exception): Unit = synchronized {
-    _jobFinished = true
-    jobResult = JobFailed(exception)
-    this.notifyAll()
-  }
+  override def jobFailed(exception: Exception): Unit =
+    jobPromise.failure(exception)
 
-  def awaitResult(): JobResult = synchronized {
-    while (!_jobFinished) {
-      this.wait()
-    }
-    return jobResult
-  }
 }
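
A stripped-down sketch (not the actual class) of the idea behind the rewritten JobWaiter: an AtomicInteger counts finished tasks and a Promise[Unit] replaces the old wait()/notifyAll() handshake, so callers consume a Future instead of parking a thread on a monitor.

    import java.util.concurrent.atomic.AtomicInteger

    import scala.concurrent.{Future, Promise}

    class SimpleWaiter(totalTasks: Int) {
      private val finishedTasks = new AtomicInteger(0)
      // A zero-task job is complete as soon as it is created.
      private val jobPromise: Promise[Unit] =
        if (totalTasks == 0) Promise.successful(()) else Promise()

      def completionFuture: Future[Unit] = jobPromise.future

      def taskSucceeded(): Unit =
        if (finishedTasks.incrementAndGet() == totalTasks) jobPromise.success(())

      def jobFailed(exception: Exception): Unit = jobPromise.failure(exception)
    }

Callers can register onComplete callbacks or Await.ready on completionFuture without holding any lock on the waiter.
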
diff --git a/core/src/test/scala/org/apache/spark/Smuggle.scala b/core/src/test/scala/org/apache/spark/Smuggle.scala
new file mode 100644
index 0000000000000000000000000000000000000000..01694a6e6f741e9e352385f5a78b886c4817f46b
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/Smuggle.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.util.UUID
+import java.util.concurrent.locks.ReentrantReadWriteLock
+
+import scala.collection.mutable
+
+/**
+  * Utility wrapper to "smuggle" objects into tasks while bypassing serialization.
+  * This is intended for testing purposes, primarily to make locks, semaphores, and
+  * other constructs that would not survive serialization available from within tasks.
+  * A Smuggle reference is itself serializable, but after being serialized and
+  * deserialized, it still refers to the same underlying "smuggled" object, as long
+  * as it was deserialized within the same JVM. This can be useful for tests that need
+  * task completion timing to be deterministic, since one can "smuggle" a lock or
+  * semaphore into a task, and the task can then block until the test gives the
+  * go-ahead to proceed via the lock.
+  */
+class Smuggle[T] private(val key: Symbol) extends Serializable {
+  def smuggledObject: T = Smuggle.get(key)
+}
+
+
+object Smuggle {
+  /**
+    * Wraps the specified object to be smuggled into a serialized task without
+    * being serialized itself.
+    *
+    * @param smuggledObject the object to wrap; it is registered locally and never serialized.
+    * @tparam T the type of the smuggled object.
+    * @return a serializable Smuggle wrapper around smuggledObject.
+    */
+  def apply[T](smuggledObject: T): Smuggle[T] = {
+    val key = Symbol(UUID.randomUUID().toString)
+    lock.writeLock().lock()
+    try {
+      smuggledObjects += key -> smuggledObject
+    } finally {
+      lock.writeLock().unlock()
+    }
+    new Smuggle(key)
+  }
+
+  private val lock = new ReentrantReadWriteLock
+  private val smuggledObjects = mutable.WeakHashMap.empty[Symbol, Any]
+
+  private def get[T](key: Symbol): T = {
+    lock.readLock().lock()
+    try {
+      smuggledObjects(key).asInstanceOf[T]
+    } finally {
+      lock.readLock().unlock()
+    }
+  }
+
+  /**
+    * Implicit conversion of a Smuggle wrapper to the object being smuggled.
+    *
+    * @param smuggle the wrapper to unpack.
+    * @tparam T the type of the smuggled object.
+    * @return the smuggled object represented by the wrapper.
+    */
+  implicit def unpackSmuggledObject[T](smuggle: Smuggle[T]): T = smuggle.smuggledObject
+
+}
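
A hypothetical test-body use of Smuggle, assuming an existing local-mode SparkContext sc: the Semaphore is never serialized, only the Smuggle key travels with the closure, and the implicit conversion in the companion object lets the task call Semaphore methods directly on the wrapper.

    import java.util.concurrent.Semaphore

    import org.apache.spark.Smuggle

    val gate = Smuggle(new Semaphore(0))
    val rdd = sc.parallelize(1 to 10, 2).mapPartitions { iter =>
      gate.acquire()   // each task blocks here until the test hands out a permit
      iter
    }
    val future = rdd.countAsync()
    gate.release(rdd.partitions.length)   // let every partition proceed
    // `future` now completes once both tasks have run.
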
diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala
index 46516e8d25298c095e9ff9a7207069db341b4693..5483f2b8434aad5ecf9bcf7938adc32db347e15f 100644
--- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala
@@ -86,4 +86,30 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkCont
         Set(firstJobId, secondJobId))
     }
   }
+
+  test("getJobIdsForGroup() with takeAsync()") {
+    sc = new SparkContext("local", "test", new SparkConf(false))
+    sc.setJobGroup("my-job-group2", "description")
+    sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty
+    val firstJobFuture = sc.parallelize(1 to 1000, 1).takeAsync(1)
+    val firstJobId = eventually(timeout(10 seconds)) {
+      firstJobFuture.jobIds.head
+    }
+    eventually(timeout(10 seconds)) {
+      sc.statusTracker.getJobIdsForGroup("my-job-group2") should be (Seq(firstJobId))
+    }
+  }
+
+  test("getJobIdsForGroup() with takeAsync() across multiple partitions") {
+    sc = new SparkContext("local", "test", new SparkConf(false))
+    sc.setJobGroup("my-job-group2", "description")
+    sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty
+    val firstJobFuture = sc.parallelize(1 to 1000, 2).takeAsync(999)
+    val firstJobId = eventually(timeout(10 seconds)) {
+      firstJobFuture.jobIds.head
+    }
+    eventually(timeout(10 seconds)) {
+      sc.statusTracker.getJobIdsForGroup("my-job-group2") should have size 2
+    }
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index ec99f2a1bad660899b62bfa139440c2d0d61abb3..de015ebd5d23704f737e1acb5f7479fe66b296fc 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.rdd
 
 import java.util.concurrent.Semaphore
 
-import scala.concurrent.{Await, TimeoutException}
+import scala.concurrent._
 import scala.concurrent.duration.Duration
 import scala.concurrent.ExecutionContext.Implicits.global
 
@@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterAll
 import org.scalatest.concurrent.Timeouts
 import org.scalatest.time.SpanSugar._
 
-import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark._
 
 class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Timeouts {
 
@@ -197,4 +197,33 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim
       Await.result(f, Duration(20, "milliseconds"))
     }
   }
+
+  private def testAsyncAction[R](action: RDD[Int] => FutureAction[R]): Unit = {
+    val executionContextInvoked = Promise[Unit]
+    val fakeExecutionContext = new ExecutionContext {
+      override def execute(runnable: Runnable): Unit = {
+        executionContextInvoked.success(())
+      }
+      override def reportFailure(t: Throwable): Unit = ()
+    }
+    val starter = Smuggle(new Semaphore(0))
+    starter.drainPermits()
+    val rdd = sc.parallelize(1 to 100, 4).mapPartitions { itr => starter.acquire(1); itr }
+    val f = action(rdd)
+    f.onComplete(_ => ())(fakeExecutionContext)
+    // Here we verify that registering the callback didn't cause a thread to be consumed.
+    assert(!executionContextInvoked.isCompleted)
+    // Now allow the executors to proceed with task processing.
+    starter.release(rdd.partitions.length)
+    // Waiting for the callback's future verifies that the tasks were successfully processed.
+    Await.result(executionContextInvoked.future, atMost = 15.seconds)
+  }
+
+  test("SimpleFutureAction callback must not consume a thread while waiting") {
+    testAsyncAction(_.countAsync())
+  }
+
+  test("ComplexFutureAction callback must not consume a thread while waiting") {
+    testAsyncAction(_.takeAsync(100))
+  }
 }