diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 465cc1fa7d88dcb08daf9ab2f6d1e46650219179..64e354e2e3eac8d747fd65310efcb65986989780 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -23,6 +23,7 @@ import scala.util.Try
 
 import org.apache.spark.scheduler.{JobSucceeded, JobWaiter}
 import org.apache.spark.scheduler.JobFailed
+import org.apache.spark.rdd.RDD
 
 
 /**
@@ -170,14 +171,13 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
   }
 
   /**
-   * Executes some action enclosed in the closure. This execution of func is wrapped in a
-   * synchronized block to guarantee that this promise can only be cancelled when the task is
-   * waiting for
+   * Executes some action enclosed in the closure. To properly enable cancellation, the closure
+   * should use runJob implementation in this promise. See takeAsync for example.
    */
   def run(func: => T)(implicit executor: ExecutionContext): Unit = scala.concurrent.future {
     thread = Thread.currentThread
     try {
-      this.success(this.synchronized {
+      this.success({
         if (cancelled) {
           // This action has been cancelled before this thread even started running.
           throw new InterruptedException
@@ -191,6 +191,38 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
     }
   }
 
+  /**
+   * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext
+   * to enable cancellation.
+   */
+  def runJob[T, U, R](
+      rdd: RDD[T],
+      processPartition: Iterator[T] => U,
+      partitions: Seq[Int],
+      partitionResultHandler: (Int, U) => Unit,
+      resultFunc: => R) {
+    // If the action hasn't been cancelled yet, submit the job. The check and the submitJob
+    // command need to be in an atomic block.
+    val job = this.synchronized {
+      if (!cancelled) {
+        rdd.context.submitJob(rdd, processPartition, partitions, partitionResultHandler, resultFunc)
+      } else {
+        throw new SparkException("action has been cancelled")
+      }
+    }
+
+    // Wait for the job to complete. If the action is cancelled (with an interrupt),
+    // cancel the job and stop the execution. This is not in a synchronized block because
+    // Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
+    try {
+      Await.ready(job, Duration.Inf)
+    } catch {
+      case e: InterruptedException =>
+        job.cancel()
+        throw new SparkException("action has been cancelled")
+    }
+  }
+
   /**
    * Returns whether the promise has been cancelled.
    */
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index 6806b8730b2482689991400d4930eb814a879756..579832427e11376525b6556f5ed1aab307569cc1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -20,8 +20,6 @@ package org.apache.spark.rdd
 import java.util.concurrent.atomic.AtomicLong
 
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.Await
-import scala.concurrent.duration.Duration
 import scala.concurrent.ExecutionContext.Implicits.global
 
 import org.apache.spark.{Logging, CancellablePromise, FutureAction}
@@ -90,22 +88,12 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with
         val left = num - buf.size
         val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
 
-        val job = self.context.submitJob(
-          self,
+        promise.runJob(self,
           (it: Iterator[T]) => it.take(left).toArray,
           p,
           (index: Int, data: Array[T]) => buf ++= data.take(num - buf.size),
           Unit)
 
-        // Wait for the job to complete. If the action is cancelled (with an interrupt),
-        // cancel the job and stop the execution.
-        try {
-          Await.result(job, Duration.Inf)
-        } catch {
-          case e: InterruptedException =>
-            job.cancel()
-            throw e
-        }
         partsScanned += numPartsToTry
       }
       buf.toSeq
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index 758670bdbf8e730c2e9c844ac9caa56e8e79a03b..029f24a51be4637240356933006f85812328f0ff 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -20,16 +20,13 @@ package org.apache.spark.rdd
 import java.util.concurrent.Semaphore
 import java.util.concurrent.atomic.AtomicInteger
 
-import scala.concurrent.Await
 import scala.concurrent.future
-import scala.concurrent.duration._
 import scala.concurrent.ExecutionContext.Implicits.global
 
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
 import org.apache.spark.SparkContext._
 import org.apache.spark.{SparkContext, SparkException, LocalSparkContext}
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
 import org.apache.spark.scheduler._
 
 
@@ -81,135 +78,154 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
     assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
   }
 
-//
-//  test("countAsync") {
-//    assert(zeroPartRdd.countAsync().get() === 0)
-//    assert(sc.parallelize(1 to 10000, 5).countAsync().get() === 10000)
-//  }
-//
-//  test("collectAsync") {
-//    assert(zeroPartRdd.collectAsync().get() === Seq.empty)
-//
-//    // Note that we sort the collected output because the order is indeterministic.
-//    val collected = sc.parallelize(1 to 1000, 3).collectAsync().get().sorted
-//    assert(collected === (1 to 1000))
-//  }
-//
-//  test("foreachAsync") {
-//    zeroPartRdd.foreachAsync(i => Unit).get()
-//
-//    val accum = sc.accumulator(0)
-//    sc.parallelize(1 to 1000, 3).foreachAsync { i =>
-//      accum += 1
-//    }.get()
-//    assert(accum.value === 1000)
-//  }
-//
-//  test("foreachPartitionAsync") {
-//    zeroPartRdd.foreachPartitionAsync(iter => Unit).get()
-//
-//    val accum = sc.accumulator(0)
-//    sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter =>
-//      accum += 1
-//    }.get()
-//    assert(accum.value === 9)
-//  }
-//
-//  test("takeAsync") {
-//    def testTake(rdd: RDD[Int], input: Seq[Int], num: Int) {
-//      // Note that we sort the collected output because the order is indeterministic.
-//      assert(rdd.takeAsync(num).get().size === input.take(num).size)
-//    }
-//    val input = Range(1, 1000)
-//
-//    var nums = sc.parallelize(input, 1)
-//    for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
-//      testTake(nums, input, num)
-//    }
-//
-//    nums = sc.parallelize(input, 2)
-//    for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
-//      testTake(nums, input, num)
-//    }
-//
-//    nums = sc.parallelize(input, 100)
-//    for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
-//      testTake(nums, input, num)
-//    }
-//
-//    nums = sc.parallelize(input, 1000)
-//    for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
-//      testTake(nums, input, num)
-//    }
-//  }
-//
-//  /**
-//   * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
-//   * of a successful job execution.
-//   */
-//  test("async success handling") {
-//    val f = sc.parallelize(1 to 10, 2).countAsync()
-//
-//    // This semaphore is used to make sure our final assert waits until onComplete / onSuccess
-//    // finishes execution.
-//    val sem = new Semaphore(0)
-//
-//    AsyncRDDActionsSuite.asyncSuccessHappened.set(0)
-//    f.onComplete {
-//      case scala.util.Success(res) =>
-//        AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
-//        sem.release()
-//      case scala.util.Failure(e) =>
-//        throw new Exception("Task should succeed")
-//        sem.release()
-//    }
-//    f.onSuccess { case a: Any =>
-//      AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
-//      sem.release()
-//    }
-//    f.onFailure { case t =>
-//      throw new Exception("Task should succeed")
-//    }
-//    assert(f.get() === 10)
-//    sem.acquire(2)
-//    assert(AsyncRDDActionsSuite.asyncSuccessHappened.get() === 2)
-//  }
-//
-//  /**
-//   * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
-//   * of a failed job execution.
-//   */
-//  test("async failure handling") {
-//    val f = sc.parallelize(1 to 10, 2).map { i =>
-//      throw new Exception("intentional"); i
-//    }.countAsync()
-//
-//    // This semaphore is used to make sure our final assert waits until onComplete / onFailure
-//    // finishes execution.
-//    val sem = new Semaphore(0)
-//
-//    AsyncRDDActionsSuite.asyncFailureHappend.set(0)
-//    f.onComplete {
-//      case scala.util.Success(res) =>
-//        throw new Exception("Task should fail")
-//        sem.release()
-//      case scala.util.Failure(e) =>
-//        AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
-//        sem.release()
-//    }
-//    f.onSuccess { case a: Any =>
-//      throw new Exception("Task should fail")
-//    }
-//    f.onFailure { case t =>
-//      AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
-//      sem.release()
-//    }
-//    intercept[SparkException] {
-//      f.get()
-//    }
-//    sem.acquire(2)
-//    assert(AsyncRDDActionsSuite.asyncFailureHappend.get() === 2)
-//  }
+  test("cancelling take action after some tasks have been launched") {
+    // Add a listener to release the semaphore once any tasks are launched.
+    val sem = new Semaphore(0)
+    sc.dagScheduler.addSparkListener(new SparkListener {
+      override def onTaskStart(taskStart: SparkListenerTaskStart) {
+        sem.release()
+      }
+    })
+    val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000)
+    future {
+      sem.acquire()
+      f.cancel()
+    }
+    val e = intercept[SparkException] { f.get() }
+    assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+  }
+
+  test("countAsync") {
+    assert(zeroPartRdd.countAsync().get() === 0)
+    assert(sc.parallelize(1 to 10000, 5).countAsync().get() === 10000)
+  }
+
+  test("collectAsync") {
+    assert(zeroPartRdd.collectAsync().get() === Seq.empty)
+
+    // Note that we sort the collected output because the order is indeterministic.
+    val collected = sc.parallelize(1 to 1000, 3).collectAsync().get().sorted
+    assert(collected === (1 to 1000))
+  }
+
+  test("foreachAsync") {
+    zeroPartRdd.foreachAsync(i => Unit).get()
+
+    val accum = sc.accumulator(0)
+    sc.parallelize(1 to 1000, 3).foreachAsync { i =>
+      accum += 1
+    }.get()
+    assert(accum.value === 1000)
+  }
+
+  test("foreachPartitionAsync") {
+    zeroPartRdd.foreachPartitionAsync(iter => Unit).get()
+
+    val accum = sc.accumulator(0)
+    sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter =>
+      accum += 1
+    }.get()
+    assert(accum.value === 9)
+  }
+
+  test("takeAsync") {
+    def testTake(rdd: RDD[Int], input: Seq[Int], num: Int) {
+      // Note that we sort the collected output because the order is indeterministic.
+      val expected = input.take(num).size
+      val saw = rdd.takeAsync(num).get().size
+      assert(saw == expected, "incorrect result for rdd with %d partitions (expected %d, saw %d)"
+        .format(rdd.partitions.size, expected, saw))
+    }
+    val input = Range(1, 1000)
+
+    var rdd = sc.parallelize(input, 1)
+    for (num <- Seq(0, 1, 999, 1000)) {
+      testTake(rdd, input, num)
+    }
+
+    rdd = sc.parallelize(input, 2)
+    for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
+      testTake(rdd, input, num)
+    }
+
+    rdd = sc.parallelize(input, 100)
+    for (num <- Seq(0, 1, 500, 501, 999, 1000)) {
+      testTake(rdd, input, num)
+    }
+
+    rdd = sc.parallelize(input, 1000)
+    for (num <- Seq(0, 1, 3, 999, 1000)) {
+      testTake(rdd, input, num)
+    }
+  }
+
+  /**
+   * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
+   * of a successful job execution.
+   */
+  test("async success handling") {
+    val f = sc.parallelize(1 to 10, 2).countAsync()
+
+    // This semaphore is used to make sure our final assert waits until onComplete / onSuccess
+    // finishes execution.
+    val sem = new Semaphore(0)
+
+    AsyncRDDActionsSuite.asyncSuccessHappened.set(0)
+    f.onComplete {
+      case scala.util.Success(res) =>
+        AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
+        sem.release()
+      case scala.util.Failure(e) =>
+        throw new Exception("Task should succeed")
+        sem.release()
+    }
+    f.onSuccess { case a: Any =>
+      AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
+      sem.release()
+    }
+    f.onFailure { case t =>
+      throw new Exception("Task should succeed")
+    }
+    assert(f.get() === 10)
+    sem.acquire(2)
+    assert(AsyncRDDActionsSuite.asyncSuccessHappened.get() === 2)
+  }
+
+  /**
+   * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
+   * of a failed job execution.
+   */
+  test("async failure handling") {
+    val f = sc.parallelize(1 to 10, 2).map { i =>
+      throw new Exception("intentional"); i
+    }.countAsync()
+
+    // This semaphore is used to make sure our final assert waits until onComplete / onFailure
+    // finishes execution.
+    val sem = new Semaphore(0)
+
+    AsyncRDDActionsSuite.asyncFailureHappend.set(0)
+    f.onComplete {
+      case scala.util.Success(res) =>
+        throw new Exception("Task should fail")
+        sem.release()
+      case scala.util.Failure(e) =>
+        AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
+        sem.release()
+    }
+    f.onSuccess { case a: Any =>
+      throw new Exception("Task should fail")
+    }
+    f.onFailure { case t =>
+      AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
+      sem.release()
+    }
+    intercept[SparkException] {
+      f.get()
+    }
+    sem.acquire(2)
+    assert(AsyncRDDActionsSuite.asyncFailureHappend.get() === 2)
+  }
 }
 
 object AsyncRDDActionsSuite {
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 31f97fc1391e862a04a71aee36e0438aba4245bb..d7e9ccafb30c7370eb7276b112732d1c036409e8 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -106,7 +106,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
       }
     }
     visit(sums)
-    assert(deps.size === 2) // ShuffledRDD, ParallelCollection
+    assert(deps.size === 3) // ShuffledRDD, ParallelCollection, InterruptibleRDD.
   }
 
   test("join") {