From 5e34855cf04145cc3b7bae996c2a6e668f144a11 Mon Sep 17 00:00:00 2001
From: Prashant Sharma <prashant.s@imaginea.com>
Date: Fri, 26 Sep 2014 21:29:54 -0700
Subject: [PATCH] [SPARK-3543] Write TaskContext in Java and expose it through
 a static accessor.

Author: Prashant Sharma <prashant.s@imaginea.com>
Author: Shashank Sharma <shashank21j@gmail.com>

Closes #2425 from ScrapCodes/SPARK-3543/withTaskContext and squashes the following commits:

8ae414c [Shashank Sharma] CR
ee8bd00 [Prashant Sharma] Added internal API in docs comments.
ddb8cbe [Prashant Sharma] Moved setting the thread local to where TaskContext is instantiated.
a7d5e23 [Prashant Sharma] Added doc comments.
edf945e [Prashant Sharma] Code review git add -A
f716fd1 [Prashant Sharma] introduced thread local for getting the task context.
333c7d6 [Prashant Sharma] Translated Task context from scala to java.
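
A minimal usage sketch with the new accessor, assuming a live SparkContext
named `sc`: code running inside a task can read its TaskContext from the
thread local that the scheduler now sets (see the Task.scala and
DAGScheduler.scala hunks below).

    // TaskContext.get() returns the context registered for the current
    // thread before the task body runs; getStageId/getPartitionId are the
    // new Java-style getters added in this patch.
    sc.parallelize(1 to 100, 4).mapPartitions { iter =>
      val ctx = org.apache.spark.TaskContext.get()
      println(s"stage=${ctx.getStageId()}, partition=${ctx.getPartitionId()}")
      iter
    }.count()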
---
 .../java/org/apache/spark/TaskContext.java    | 274 ++++++++++++++++++
 .../scala/org/apache/spark/TaskContext.scala  | 126 --------
 .../main/scala/org/apache/spark/rdd/RDD.scala |   1 +
 .../apache/spark/scheduler/DAGScheduler.scala |   4 +-
 .../org/apache/spark/scheduler/Task.scala     |   6 +-
 .../java/org/apache/spark/JavaAPISuite.java   |   2 +-
 .../org/apache/spark/CacheManagerSuite.scala  |   2 +-
 7 files changed, 284 insertions(+), 131 deletions(-)
 create mode 100644 core/src/main/java/org/apache/spark/TaskContext.java
 delete mode 100644 core/src/main/scala/org/apache/spark/TaskContext.scala

diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java
new file mode 100644
index 0000000000..09b8ce02bd
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/TaskContext.java
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import scala.Function0;
+import scala.Function1;
+import scala.Unit;
+import scala.collection.JavaConversions;
+
+import org.apache.spark.annotation.DeveloperApi;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.util.TaskCompletionListener;
+import org.apache.spark.util.TaskCompletionListenerException;
+
+/**
+ * :: DeveloperApi ::
+ * Contextual information about a task which can be read or mutated during execution.
+ */
+@DeveloperApi
+public class TaskContext implements Serializable {
+
+  private int stageId;
+  private int partitionId;
+  private long attemptId;
+  private boolean runningLocally;
+  private TaskMetrics taskMetrics;
+
+  /**
+   * :: DeveloperApi ::
+   * Contextual information about a task which can be read or mutated during execution.
+   *
+   * @param stageId stage id
+   * @param partitionId index of the partition
+   * @param attemptId the number of attempts to execute this task
+   * @param runningLocally whether the task is running locally in the driver JVM
+   * @param taskMetrics performance metrics of the task
+   */
+  @DeveloperApi
+  public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean runningLocally,
+                     TaskMetrics taskMetrics) {
+    this.attemptId = attemptId;
+    this.partitionId = partitionId;
+    this.runningLocally = runningLocally;
+    this.stageId = stageId;
+    this.taskMetrics = taskMetrics;
+  }
+
+
+  /**
+   * :: DeveloperApi ::
+   * Contextual information about a task which can be read or mutated during execution.
+   *
+   * @param stageId stage id
+   * @param partitionId index of the partition
+   * @param attemptId the number of attempts to execute this task
+   * @param runningLocally whether the task is running locally in the driver JVM
+   */
+  @DeveloperApi
+  public TaskContext(Integer stageId, Integer partitionId, Long attemptId,
+                     Boolean runningLocally) {
+    this.attemptId = attemptId;
+    this.partitionId = partitionId;
+    this.runningLocally = runningLocally;
+    this.stageId = stageId;
+    this.taskMetrics = TaskMetrics.empty();
+  }
+
+
+  /**
+   * :: DeveloperApi ::
+   * Contextual information about a task which can be read or mutated during execution.
+   *
+   * @param stageId stage id
+   * @param partitionId index of the partition
+   * @param attemptId the number of attempts to execute this task
+   */
+  @DeveloperApi
+  public TaskContext(Integer stageId, Integer partitionId, Long attemptId) {
+    this.attemptId = attemptId;
+    this.partitionId = partitionId;
+    this.runningLocally = false;
+    this.stageId = stageId;
+    this.taskMetrics = TaskMetrics.empty();
+  }
+
+  private static ThreadLocal<TaskContext> taskContext =
+    new ThreadLocal<TaskContext>();
+
+  /**
+   * :: Internal API ::
+   * This is a Spark internal API, not intended to be called from user programs.
+   */
+  public static void setTaskContext(TaskContext tc) {
+    taskContext.set(tc);
+  }
+
+  public static TaskContext get() {
+    return taskContext.get();
+  }
+
+  /**
+   * :: Internal API ::
+   */
+  public static void remove() {
+    taskContext.remove();
+  }
+
+  // List of callback functions to execute when the task completes.
+  private transient List<TaskCompletionListener> onCompleteCallbacks =
+    new ArrayList<TaskCompletionListener>();
+
+  // Whether the corresponding task has been killed.
+  private volatile Boolean interrupted = false;
+
+  // Whether the task has completed.
+  private volatile Boolean completed = false;
+
+  /**
+   * Checks whether the task has completed.
+   */
+  public Boolean isCompleted() {
+    return completed;
+  }
+
+  /**
+   * Checks whether the task has been killed.
+   */
+  public Boolean isInterrupted() {
+    return interrupted;
+  }
+
+  /**
+   * Add a (Java friendly) listener to be executed on task completion.
+   * This will be called in all situations - success, failure, or cancellation.
+   * <p/>
+   * An example use is for HadoopRDD to register a callback to close the input stream.
+   */
+  public TaskContext addTaskCompletionListener(TaskCompletionListener listener) {
+    onCompleteCallbacks.add(listener);
+    return this;
+  }
+
+  /**
+   * Add a listener in the form of a Scala closure to be executed on task completion.
+   * This will be called in all situations - success, failure, or cancellation.
+   * <p/>
+   * An example use is for HadoopRDD to register a callback to close the input stream.
+   */
+  public TaskContext addTaskCompletionListener(final Function1<TaskContext, Unit> f) {
+    onCompleteCallbacks.add(new TaskCompletionListener() {
+      @Override
+      public void onTaskCompletion(TaskContext context) {
+        f.apply(context);
+      }
+    });
+    return this;
+  }
+
+  /**
+   * Add a callback function to be executed on task completion. An example use
+   * is for HadoopRDD to register a callback to close the input stream.
+   * Will be called in any situation - success, failure, or cancellation.
+   *
+   * Deprecated: use addTaskCompletionListener
+   * 
+   * @param f Callback function.
+   */
+  @Deprecated
+  public void addOnCompleteCallback(final Function0<Unit> f) {
+    onCompleteCallbacks.add(new TaskCompletionListener() {
+      @Override
+      public void onTaskCompletion(TaskContext context) {
+        f.apply();
+      }
+    });
+  }
+
+  /**
+   * ::Internal API::
+   * Marks the task as completed and triggers the listeners.
+   */
+  public void markTaskCompleted() throws TaskCompletionListenerException {
+    completed = true;
+    List<String> errorMsgs = new ArrayList<String>(2);
+    // Process complete callbacks in the reverse order of registration
+    List<TaskCompletionListener> revlist =
+      new ArrayList<TaskCompletionListener>(onCompleteCallbacks);
+    Collections.reverse(revlist);
+    for (TaskCompletionListener tcl: revlist) {
+      try {
+        tcl.onTaskCompletion(this);
+      } catch (Throwable e) {
+        errorMsgs.add(e.getMessage());
+      }
+    }
+
+    if (!errorMsgs.isEmpty()) {
+      throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs));
+    }
+  }
+
+  /**
+   * ::Internal API::
+   * Marks the task for interruption, i.e. cancellation.
+   */
+  public void markInterrupted() {
+    interrupted = true;
+  }
+
+  /** Deprecated: use getStageId() */
+  @Deprecated
+  public int stageId() {
+    return stageId;
+  }
+
+  /** Deprecated: use getPartitionId() */
+  @Deprecated
+  public int partitionId() {
+    return partitionId;
+  }
+
+  /** Deprecated: use getAttemptId() */
+  @Deprecated
+  public long attemptId() {
+    return attemptId;
+  }
+
+  /** Deprecated: use getRunningLocally() */
+  @Deprecated
+  public boolean runningLocally() {
+    return runningLocally;
+  }
+
+  public boolean getRunningLocally() {
+    return runningLocally;
+  }
+
+  public int getStageId() {
+    return stageId;
+  }
+
+  public int getPartitionId() {
+    return partitionId;
+  }
+
+  public long getAttemptId() {
+    return attemptId;
+  }
+
+  /** ::Internal API:: */
+  public TaskMetrics taskMetrics() {
+    return taskMetrics;
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
deleted file mode 100644
index 51b3e4d5e0..0000000000
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener}
-
-
-/**
- * :: DeveloperApi ::
- * Contextual information about a task which can be read or mutated during execution.
- *
- * @param stageId stage id
- * @param partitionId index of the partition
- * @param attemptId the number of attempts to execute this task
- * @param runningLocally whether the task is running locally in the driver JVM
- * @param taskMetrics performance metrics of the task
- */
-@DeveloperApi
-class TaskContext(
-    val stageId: Int,
-    val partitionId: Int,
-    val attemptId: Long,
-    val runningLocally: Boolean = false,
-    private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty)
-  extends Serializable with Logging {
-
-  @deprecated("use partitionId", "0.8.1")
-  def splitId = partitionId
-
-  // List of callback functions to execute when the task completes.
-  @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
-
-  // Whether the corresponding task has been killed.
-  @volatile private var interrupted: Boolean = false
-
-  // Whether the task has completed.
-  @volatile private var completed: Boolean = false
-
-  /** Checks whether the task has completed. */
-  def isCompleted: Boolean = completed
-
-  /** Checks whether the task has been killed. */
-  def isInterrupted: Boolean = interrupted
-
-  // TODO: Also track whether the task has completed successfully or with exception.
-
-  /**
-   * Add a (Java friendly) listener to be executed on task completion.
-   * This will be called in all situation - success, failure, or cancellation.
-   *
-   * An example use is for HadoopRDD to register a callback to close the input stream.
-   */
-  def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
-    onCompleteCallbacks += listener
-    this
-  }
-
-  /**
-   * Add a listener in the form of a Scala closure to be executed on task completion.
-   * This will be called in all situation - success, failure, or cancellation.
-   *
-   * An example use is for HadoopRDD to register a callback to close the input stream.
-   */
-  def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
-    onCompleteCallbacks += new TaskCompletionListener {
-      override def onTaskCompletion(context: TaskContext): Unit = f(context)
-    }
-    this
-  }
-
-  /**
-   * Add a callback function to be executed on task completion. An example use
-   * is for HadoopRDD to register a callback to close the input stream.
-   * Will be called in any situation - success, failure, or cancellation.
-   * @param f Callback function.
-   */
-  @deprecated("use addTaskCompletionListener", "1.1.0")
-  def addOnCompleteCallback(f: () => Unit) {
-    onCompleteCallbacks += new TaskCompletionListener {
-      override def onTaskCompletion(context: TaskContext): Unit = f()
-    }
-  }
-
-  /** Marks the task as completed and triggers the listeners. */
-  private[spark] def markTaskCompleted(): Unit = {
-    completed = true
-    val errorMsgs = new ArrayBuffer[String](2)
-    // Process complete callbacks in the reverse order of registration
-    onCompleteCallbacks.reverse.foreach { listener =>
-      try {
-        listener.onTaskCompletion(this)
-      } catch {
-        case e: Throwable =>
-          errorMsgs += e.getMessage
-          logError("Error in TaskCompletionListener", e)
-      }
-    }
-    if (errorMsgs.nonEmpty) {
-      throw new TaskCompletionListenerException(errorMsgs)
-    }
-  }
-
-  /** Marks the task for interruption, i.e. cancellation. */
-  private[spark] def markInterrupted(): Unit = {
-    interrupted = true
-  }
-}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 0e90caa5c9..ba712c9d77 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -619,6 +619,7 @@ abstract class RDD[T: ClassTag](
    * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
    */
   @DeveloperApi
+  @deprecated("use TaskContext.get", "1.2.0")
   def mapPartitionsWithContext[U: ClassTag](
       f: (TaskContext, Iterator[T]) => Iterator[U],
       preservesPartitioning: Boolean = false): RDD[U] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index b2774dfc47..32cf29ed14 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -634,12 +634,14 @@ class DAGScheduler(
       val rdd = job.finalStage.rdd
       val split = rdd.partitions(job.partitions(0))
       val taskContext =
-        new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true)
+        new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
+      TaskContext.setTaskContext(taskContext)
       try {
         val result = job.func(taskContext, rdd.iterator(split, taskContext))
         job.listener.taskSucceeded(0, result)
       } finally {
         taskContext.markTaskCompleted()
+        TaskContext.remove()
       }
     } catch {
       case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 6aa0cca068..bf73f6f7bd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -45,7 +45,8 @@ import org.apache.spark.util.Utils
 private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
 
   final def run(attemptId: Long): T = {
-    context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
+    context = new TaskContext(stageId, partitionId, attemptId, false)
+    TaskContext.setTaskContext(context)
     context.taskMetrics.hostname = Utils.localHostName()
     taskThread = Thread.currentThread()
     if (_killed) {
@@ -92,7 +93,8 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
     if (interruptThread && taskThread != null) {
       taskThread.interrupt()
     }
-  }
+    TaskContext.remove()
+  }
 }
 
 /**
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index b8c23d524e..4a07843544 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -776,7 +776,7 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContext(0, 0, 0, false, new TaskMetrics());
+    TaskContext context = new TaskContext(0, 0, 0L, false, new TaskMetrics());
     Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
   }
 
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 90dcadcffd..d735010d7c 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     }
 
     whenExecuting(blockManager) {
-      val context = new TaskContext(0, 0, 0, runningLocally = true)
+      val context = new TaskContext(0, 0, 0, true)
       val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
       assert(value.toList === List(1, 2, 3, 4))
     }
-- 
GitLab