From f350cd307045c2c02e713225d8f1247f18ba123e Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@apache.org>
Date: Sun, 28 Sep 2014 20:32:54 -0700
Subject: [PATCH] [SPARK-3543] TaskContext remaining cleanup work.

Author: Reynold Xin <rxin@apache.org>

Closes #2560 from rxin/TaskContext and squashes the following commits:

9eff95a [Reynold Xin] [SPARK-3543] remaining cleanup work.
---
 core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala   | 2 +-
 .../main/scala/org/apache/spark/rdd/PairRDDFunctions.scala | 3 ++-
 .../apache/spark/util/JavaTaskCompletionListenerImpl.java  | 7 +++----
 .../serializer/ProactiveClosureSerializationSuite.scala    | 6 +-----
 .../apache/spark/sql/parquet/ParquetTableOperations.scala  | 4 ++--
 5 files changed, 9 insertions(+), 13 deletions(-)
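
For context, every hunk in this patch applies the same cleanup: callers move from the old field-style TaskContext accessors (context.stageId, context.attemptId, context.runningLocally) to the getter-style API (getStageId, getPartitionId, getAttemptId, isRunningLocally), and closures obtain the context via TaskContext.get() instead of having it threaded in. A minimal sketch of the updated call pattern, with method names taken from the diffs below; the local-mode job setup and sample data are illustrative only, not part of this patch:

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object TaskContextAccessorSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("taskcontext-sketch").setMaster("local[2]"))
    try {
      sc.parallelize(1 to 4, 2).foreachPartition { _ =>
        // Obtain the per-task context without threading it through the closure signature.
        val context = TaskContext.get()
        // Getter-style accessors exercised by the hunks in this patch.
        println(s"stage=${context.getStageId} partition=${context.getPartitionId} " +
          s"attempt=${context.getAttemptId} local=${context.isRunningLocally}")
      }
    } finally {
      sc.stop()
    }
  }
}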

diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 036dcc4966..21d0cc7b5c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -194,7 +194,7 @@ class HadoopRDD[K, V](
       val jobConf = getJobConf()
       val inputFormat = getInputFormat(jobConf)
       HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
-        context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
+        context.getStageId, theSplit.index, context.getAttemptId.toInt, jobConf)
       reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
 
       // Register an on-task-completion callback to close the input stream.
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 7f578bc5da..67833743f3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -86,7 +86,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     }
     val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
     if (self.partitioner == Some(partitioner)) {
-      self.mapPartitionsWithContext((context, iter) => {
+      self.mapPartitions(iter => {
+        val context = TaskContext.get()
         new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
       }, preservesPartitioning = true)
     } else {
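
The hunk above shows the replacement pattern for mapPartitionsWithContext: the closure now calls TaskContext.get() itself rather than receiving the context as a parameter. A self-contained sketch of the same idea, assuming an ordinary pair RDD; the sample data and the partition-id tagging are invented for illustration, while TaskContext.get(), InterruptibleIterator, and preservesPartitioning mirror the patched code:

import org.apache.spark.{InterruptibleIterator, SparkConf, SparkContext, TaskContext}

object MapPartitionsContextSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("mapPartitions-sketch").setMaster("local[2]"))
    try {
      val pairs = sc.parallelize(Seq("a" -> 1, "b" -> 2, "a" -> 3), 2)
      val tagged = pairs.mapPartitions({ iter =>
        // TaskContext.get() replaces the context argument of the removed mapPartitionsWithContext.
        val context = TaskContext.get()
        // Wrapping in InterruptibleIterator keeps the task-kill check, as the patched code does.
        new InterruptibleIterator(context, iter.map { case (k, v) => (k, (context.getPartitionId, v)) })
      }, preservesPartitioning = true)
      tagged.collect().foreach(println)
    } finally {
      sc.stop()
    }
  }
}
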
diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
index af34cdb03e..0944bf8cd5 100644
--- a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
+++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
@@ -30,10 +30,9 @@ public class JavaTaskCompletionListenerImpl implements TaskCompletionListener {
   public void onTaskCompletion(TaskContext context) {
     context.isCompleted();
     context.isInterrupted();
-    context.stageId();
-    context.partitionId();
-    context.runningLocally();
-    context.taskMetrics();
+    context.getStageId();
+    context.getPartitionId();
+    context.isRunningLocally();
     context.addTaskCompletionListener(this);
   }
 }
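
For reference, the listener registration exercised by the Java test can be written the same way from Scala; a minimal sketch, assuming the org.apache.spark.util.TaskCompletionListener trait implemented above (the printed message is illustrative only):

import org.apache.spark.{SparkConf, SparkContext, TaskContext}
import org.apache.spark.util.TaskCompletionListener

object CompletionListenerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("listener-sketch").setMaster("local[2]"))
    try {
      sc.parallelize(1 to 2, 2).foreachPartition { _ =>
        val context = TaskContext.get()
        // Register a callback that runs when this task finishes, mirroring the Java test above.
        context.addTaskCompletionListener(new TaskCompletionListener {
          override def onTaskCompletion(ctx: TaskContext): Unit = {
            println(s"completed partition ${ctx.getPartitionId} of stage ${ctx.getStageId}")
          }
        })
      }
    } finally {
      sc.stop()
    }
  }
}
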
diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
index aad6599589..d037e2c19a 100644
--- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
@@ -50,8 +50,7 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex
           "flatMap" -> xflatMap _,
           "filter" -> xfilter _,
           "mapPartitions" -> xmapPartitions _,
-          "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _,
-          "mapPartitionsWithContext" -> xmapPartitionsWithContext _)) {
+          "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _)) {
     val (name, xf) = transformation
 
     test(s"$name transformations throw proactive serialization exceptions") {
@@ -78,8 +77,5 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex
 
   private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] = 
     x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y)))
-
-  private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] = 
-    x.mapPartitionsWithContext((_, it) => it.map(y=>uc.op(y)))
   
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index d39e31a7fa..ffb732347d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -289,9 +289,9 @@ case class InsertIntoParquetTable(
     def writeShard(context: TaskContext, iter: Iterator[Row]): Int = {
       // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
       // around by taking a mod. We expect that no task will be attempted 2 billion times.
-      val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+      val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt
       /* "reduce task" <split #> <attempt # = spark task #> */
-      val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
+      val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId,
         attemptNumber)
       val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
       val format = new AppendingParquetOutputFormat(taskIdOffset)
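
The comment in the hunk above explains why the attempt id is folded into 32 bits before being handed to Hadoop. A tiny standalone illustration of that arithmetic, with invented sample ids:

object AttemptNumberSketch {
  // Fold a (potentially 64-bit) Spark attempt id into the 32-bit range Hadoop expects,
  // the same way the patched writeShard computes attemptNumber.
  def toHadoopAttemptNumber(sparkAttemptId: Long): Int = (sparkAttemptId % Int.MaxValue).toInt

  def main(args: Array[String]): Unit = {
    Seq(0L, 7L, Int.MaxValue.toLong + 5L).foreach { id =>
      println(s"spark attempt id $id -> hadoop attempt number ${toHadoopAttemptNumber(id)}")
    }
  }
}
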
-- 
GitLab