diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java
new file mode 100644
index 0000000000000000000000000000000000000000..09b8ce02bd3d8101d28d5cf8d949277af9ddf126
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/TaskContext.java
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import scala.Function0;
+import scala.Function1;
+import scala.Unit;
+import scala.collection.JavaConversions;
+
+import org.apache.spark.annotation.DeveloperApi;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.util.TaskCompletionListener;
+import org.apache.spark.util.TaskCompletionListenerException;
+
+/**
+ * :: DeveloperApi ::
+ * Contextual information about a task which can be read or mutated during execution.
+ */
+@DeveloperApi
+public class TaskContext implements Serializable {
+
+  private int stageId;
+  private int partitionId;
+  private long attemptId;
+  private boolean runningLocally;
+  private TaskMetrics taskMetrics;
+
+  /**
+   * :: DeveloperApi ::
+   * Contextual information about a task which can be read or mutated during execution.
+   *
+   * @param stageId stage id
+   * @param partitionId index of the partition
+   * @param attemptId the number of attempts to execute this task
+   * @param runningLocally whether the task is running locally in the driver JVM
+   * @param taskMetrics performance metrics of the task
+   */
+  @DeveloperApi
+  public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean runningLocally,
+                     TaskMetrics taskMetrics) {
+    this.attemptId = attemptId;
+    this.partitionId = partitionId;
+    this.runningLocally = runningLocally;
+    this.stageId = stageId;
+    this.taskMetrics = taskMetrics;
+  }
+
+
+  /**
+   * :: DeveloperApi ::
+   * Contextual information about a task which can be read or mutated during execution.
+   *
+   * @param stageId stage id
+   * @param partitionId index of the partition
+   * @param attemptId the number of attempts to execute this task
+   * @param runningLocally whether the task is running locally in the driver JVM
+   */
+  @DeveloperApi
+  public TaskContext(Integer stageId, Integer partitionId, Long attemptId,
+                     Boolean runningLocally) {
+    this.attemptId = attemptId;
+    this.partitionId = partitionId;
+    this.runningLocally = runningLocally;
+    this.stageId = stageId;
+    this.taskMetrics = TaskMetrics.empty();
+  }
+
+
+  /**
+   * :: DeveloperApi ::
+   * Contextual information about a task which can be read or mutated during execution.
+   *
+   * @param stageId stage id
+   * @param partitionId index of the partition
+   * @param attemptId the number of attempts to execute this task
+   */
+  @DeveloperApi
+  public TaskContext(Integer stageId, Integer partitionId, Long attemptId) {
+    this.attemptId = attemptId;
+    this.partitionId = partitionId;
+    this.runningLocally = false;
+    this.stageId = stageId;
+    this.taskMetrics = TaskMetrics.empty();
+  }
+
+  private static final ThreadLocal<TaskContext> taskContext =
+    new ThreadLocal<TaskContext>();
+
+  /**
+   * :: Internal API ::
+   * This is Spark's internal API, not intended to be called from user programs.
+   */
+  public static void setTaskContext(TaskContext tc) {
+    taskContext.set(tc);
+  }
+
+  /**
+   * Returns the TaskContext for the task running in the current thread, or null if this
+   * thread is not running a task.
+   */
+  public static TaskContext get() {
+    return taskContext.get();
+  }
+
+  /**
+   * :: Internal API ::
+   */
+  public static void remove() {
+    taskContext.remove();
+  }
+
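+  // Illustrative sketch (not part of the original patch): once the scheduler has installed the
+  // thread-local via setTaskContext, code running inside a task can read its own context, e.g.
+  //
+  //   TaskContext ctx = TaskContext.get();   // null when no task is running in this thread
+  //   int partition = ctx.getPartitionId();
+  //   int stage = ctx.getStageId();
+  //
+  // This is the pattern that replaces the mapPartitionsWithContext method deprecated in this patch.
+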
+  // List of callback functions to execute when the task completes.
+  private transient List<TaskCompletionListener> onCompleteCallbacks =
+    new ArrayList<TaskCompletionListener>();
+
+  // Whether the corresponding task has been killed.
+  private volatile Boolean interrupted = false;
+
+  // Whether the task has completed.
+  private volatile Boolean completed = false;
+
+  /**
+   * Checks whether the task has completed.
+   */
+  public Boolean isCompleted() {
+    return completed;
+  }
+
+  /**
+   * Checks whether the task has been killed.
+   */
+  public Boolean isInterrupted() {
+    return interrupted;
+  }
+
+  /**
+   * Add a (Java friendly) listener to be executed on task completion.
+   * This will be called in all situations - success, failure, or cancellation.
+   * <p/>
+   * An example use is for HadoopRDD to register a callback to close the input stream.
+   */
+  public TaskContext addTaskCompletionListener(TaskCompletionListener listener) {
+    onCompleteCallbacks.add(listener);
+    return this;
+  }
+
+  /**
+   * Add a listener in the form of a Scala closure to be executed on task completion.
+   * This will be called in all situations - success, failure, or cancellation.
+   * <p/>
+   * An example use is for HadoopRDD to register a callback to close the input stream.
+   */
+  public TaskContext addTaskCompletionListener(final Function1<TaskContext, Unit> f) {
+    onCompleteCallbacks.add(new TaskCompletionListener() {
+      @Override
+      public void onTaskCompletion(TaskContext context) {
+        f.apply(context);
+      }
+    });
+    return this;
+  }
+
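+  // Illustrative sketch (not part of the original patch): a task holding a closable resource
+  // can register cleanup through addTaskCompletionListener; openSomeStream() below is a
+  // hypothetical helper standing in for whatever the task actually opened.
+  //
+  //   final java.io.InputStream in = openSomeStream();
+  //   TaskContext.get().addTaskCompletionListener(new TaskCompletionListener() {
+  //     @Override
+  //     public void onTaskCompletion(TaskContext context) {
+  //       try { in.close(); } catch (java.io.IOException e) { /* log and continue */ }
+  //     }
+  //   });
+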
+  /**
+   * Add a callback function to be executed on task completion. An example use
+   * is for HadoopRDD to register a callback to close the input stream.
+   * Will be called in any situation - success, failure, or cancellation.
+   *
+   * Deprecated: use addTaskCompletionListener
+   *
+   * @param f Callback function.
+   */
+  @Deprecated
+  public void addOnCompleteCallback(final Function0<Unit> f) {
+    onCompleteCallbacks.add(new TaskCompletionListener() {
+      @Override
+      public void onTaskCompletion(TaskContext context) {
+        f.apply();
+      }
+    });
+  }
+
+  /**
+   * :: Internal API ::
+   * Marks the task as completed and triggers the listeners.
+   */
+  public void markTaskCompleted() throws TaskCompletionListenerException {
+    completed = true;
+    List<String> errorMsgs = new ArrayList<String>(2);
+    // Process complete callbacks in the reverse order of registration
+    List<TaskCompletionListener> revlist =
+      new ArrayList<TaskCompletionListener>(onCompleteCallbacks);
+    Collections.reverse(revlist);
+    for (TaskCompletionListener tcl: revlist) {
+      try {
+        tcl.onTaskCompletion(this);
+      } catch (Throwable e) {
+        errorMsgs.add(e.getMessage());
+      }
+    }
+
+    if (!errorMsgs.isEmpty()) {
+      throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs));
+    }
+  }
+
+  /**
+   * :: Internal API ::
+   * Marks the task for interruption, i.e. cancellation.
+   */
+  public void markInterrupted() {
+    interrupted = true;
+  }
+
+  /** Deprecated: use getStageId() */
+  @Deprecated
+  public int stageId() {
+    return stageId;
+  }
+
+  /** Deprecated: use getPartitionId() */
+  @Deprecated
+  public int partitionId() {
+    return partitionId;
+  }
+
+  /** Deprecated: use getAttemptId() */
+  @Deprecated
+  public long attemptId() {
+    return attemptId;
+  }
+
+  /** Deprecated: use getRunningLocally() */
+  @Deprecated
+  public boolean runningLocally() {
+    return runningLocally;
+  }
+
+  public boolean getRunningLocally() {
+    return runningLocally;
+  }
+
+  public int getStageId() {
+    return stageId;
+  }
+
+  public int getPartitionId() {
+    return partitionId;
+  }
+
+  public long getAttemptId() {
+    return attemptId;
+  }
+
+  /** :: Internal API :: */
+  public TaskMetrics taskMetrics() {
+    return taskMetrics;
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
deleted file mode 100644
index 51b3e4d5e09362cd1dfff9e9edfa85e90fac99a3..0000000000000000000000000000000000000000
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener}
-
-
-/**
- * :: DeveloperApi ::
- * Contextual information about a task which can be read or mutated during execution.
- *
- * @param stageId stage id
- * @param partitionId index of the partition
- * @param attemptId the number of attempts to execute this task
- * @param runningLocally whether the task is running locally in the driver JVM
- * @param taskMetrics performance metrics of the task
- */
-@DeveloperApi
-class TaskContext(
-    val stageId: Int,
-    val partitionId: Int,
-    val attemptId: Long,
-    val runningLocally: Boolean = false,
-    private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty)
-  extends Serializable with Logging {
-
-  @deprecated("use partitionId", "0.8.1")
-  def splitId = partitionId
-
-  // List of callback functions to execute when the task completes.
-  @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
-
-  // Whether the corresponding task has been killed.
-  @volatile private var interrupted: Boolean = false
-
-  // Whether the task has completed.
-  @volatile private var completed: Boolean = false
-
-  /** Checks whether the task has completed. */
-  def isCompleted: Boolean = completed
-
-  /** Checks whether the task has been killed. */
-  def isInterrupted: Boolean = interrupted
-
-  // TODO: Also track whether the task has completed successfully or with exception.
-
-  /**
-   * Add a (Java friendly) listener to be executed on task completion.
-   * This will be called in all situation - success, failure, or cancellation.
-   *
-   * An example use is for HadoopRDD to register a callback to close the input stream.
-   */
-  def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
-    onCompleteCallbacks += listener
-    this
-  }
-
-  /**
-   * Add a listener in the form of a Scala closure to be executed on task completion.
-   * This will be called in all situation - success, failure, or cancellation.
-   *
-   * An example use is for HadoopRDD to register a callback to close the input stream.
-   */
-  def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
-    onCompleteCallbacks += new TaskCompletionListener {
-      override def onTaskCompletion(context: TaskContext): Unit = f(context)
-    }
-    this
-  }
-
-  /**
-   * Add a callback function to be executed on task completion. An example use
-   * is for HadoopRDD to register a callback to close the input stream.
-   * Will be called in any situation - success, failure, or cancellation.
-   * @param f Callback function.
-   */
-  @deprecated("use addTaskCompletionListener", "1.1.0")
-  def addOnCompleteCallback(f: () => Unit) {
-    onCompleteCallbacks += new TaskCompletionListener {
-      override def onTaskCompletion(context: TaskContext): Unit = f()
-    }
-  }
-
-  /** Marks the task as completed and triggers the listeners. */
-  private[spark] def markTaskCompleted(): Unit = {
-    completed = true
-    val errorMsgs = new ArrayBuffer[String](2)
-    // Process complete callbacks in the reverse order of registration
-    onCompleteCallbacks.reverse.foreach { listener =>
-      try {
-        listener.onTaskCompletion(this)
-      } catch {
-        case e: Throwable =>
-          errorMsgs += e.getMessage
-          logError("Error in TaskCompletionListener", e)
-      }
-    }
-    if (errorMsgs.nonEmpty) {
-      throw new TaskCompletionListenerException(errorMsgs)
-    }
-  }
-
-  /** Marks the task for interruption, i.e. cancellation. */
-  private[spark] def markInterrupted(): Unit = {
-    interrupted = true
-  }
-}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 0e90caa5c9ca72fb4b79696d20d6408a7db7594f..ba712c9d7776ffaf027f06485124f4eaa76436a1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -619,6 +619,7 @@ abstract class RDD[T: ClassTag](
    * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
    */
   @DeveloperApi
+  @deprecated("use TaskContext.get", "1.2.0")
   def mapPartitionsWithContext[U: ClassTag](
       f: (TaskContext, Iterator[T]) => Iterator[U],
       preservesPartitioning: Boolean = false): RDD[U] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index b2774dfc475537fcc0c1492fdbc691b95227615d..32cf29ed140e6a69d43ae83187f37c17a4a6bbdf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -634,12 +634,14 @@ class DAGScheduler(
       val rdd = job.finalStage.rdd
       val split = rdd.partitions(job.partitions(0))
       val taskContext =
-        new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true)
+        new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
+      TaskContext.setTaskContext(taskContext)
       try {
         val result = job.func(taskContext, rdd.iterator(split, taskContext))
         job.listener.taskSucceeded(0, result)
       } finally {
         taskContext.markTaskCompleted()
+        TaskContext.remove()
       }
     } catch {
       case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 6aa0cca06878dea255fdab0bb19c796a1a2ccac2..bf73f6f7bd0e15de1b7c98cc5ebf00e065b8be0a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -45,7 +45,8 @@ import org.apache.spark.util.Utils
 private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
 
   final def run(attemptId: Long): T = {
-    context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
+    context = new TaskContext(stageId, partitionId, attemptId, false)
+    TaskContext.setTaskContext(context)
     context.taskMetrics.hostname = Utils.localHostName()
     taskThread = Thread.currentThread()
     if (_killed) {
@@ -92,7 +93,8 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
     if (interruptThread && taskThread != null) {
       taskThread.interrupt()
     }
-  }
+    TaskContext.remove()
+  }  
 }
 
 /**
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index b8c23d524e00b2063d6e70432e95a48ec1d6381a..4a078435447e5430255c6654de8739e6f51ad811 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -776,7 +776,7 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContext(0, 0, 0, false, new TaskMetrics());
+    TaskContext context = new TaskContext(0, 0, 0L, false, new TaskMetrics());
     Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
   }
 
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 90dcadcffd091183040a0eb99922cf8960c0fb36..d735010d7c9d597cba1ea5f1b29520f3d0c79f9b 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     }
 
     whenExecuting(blockManager) {
-      val context = new TaskContext(0, 0, 0, runningLocally = true)
+      val context = new TaskContext(0, 0, 0, true)
       val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
       assert(value.toList === List(1, 2, 3, 4))
     }