From f350cd307045c2c02e713225d8f1247f18ba123e Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@apache.org>
Date: Sun, 28 Sep 2014 20:32:54 -0700
Subject: [PATCH] [SPARK-3543] TaskContext remaining cleanup work.

Author: Reynold Xin <rxin@apache.org>

Closes #2560 from rxin/TaskContext and squashes the following commits:

9eff95a [Reynold Xin] [SPARK-3543] remaining cleanup work.
---
 core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala   | 2 +-
 .../main/scala/org/apache/spark/rdd/PairRDDFunctions.scala | 3 ++-
 .../apache/spark/util/JavaTaskCompletionListenerImpl.java  | 7 +++----
 .../serializer/ProactiveClosureSerializationSuite.scala    | 6 +-----
 .../apache/spark/sql/parquet/ParquetTableOperations.scala  | 4 ++--
 5 files changed, 9 insertions(+), 13 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 036dcc4966..21d0cc7b5c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -194,7 +194,7 @@ class HadoopRDD[K, V](
       val jobConf = getJobConf()
       val inputFormat = getInputFormat(jobConf)
       HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
-        context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
+        context.getStageId, theSplit.index, context.getAttemptId.toInt, jobConf)
       reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
 
       // Register an on-task-completion callback to close the input stream.
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 7f578bc5da..67833743f3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -86,7 +86,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     }
     val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
     if (self.partitioner == Some(partitioner)) {
-      self.mapPartitionsWithContext((context, iter) => {
+      self.mapPartitions(iter => {
+        val context = TaskContext.get()
         new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
       }, preservesPartitioning = true)
     } else {
diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
index af34cdb03e..0944bf8cd5 100644
--- a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
+++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
@@ -30,10 +30,9 @@ public class JavaTaskCompletionListenerImpl implements TaskCompletionListener {
   public void onTaskCompletion(TaskContext context) {
     context.isCompleted();
     context.isInterrupted();
-    context.stageId();
-    context.partitionId();
-    context.runningLocally();
-    context.taskMetrics();
+    context.getStageId();
+    context.getPartitionId();
+    context.isRunningLocally();
     context.addTaskCompletionListener(this);
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
index aad6599589..d037e2c19a 100644
--- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
@@ -50,8 +50,7 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex
       "flatMap" -> xflatMap _,
       "filter" -> xfilter _,
       "mapPartitions" -> xmapPartitions _,
-      "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _,
-      "mapPartitionsWithContext" -> xmapPartitionsWithContext _)) {
+      "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _)) {
     val (name, xf) = transformation
 
     test(s"$name transformations throw proactive serialization exceptions") {
@@ -78,8 +77,5 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex
 
   private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
     x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y)))
-
-  private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
-    x.mapPartitionsWithContext((_, it) => it.map(y=>uc.op(y)))
 
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index d39e31a7fa..ffb732347d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -289,9 +289,9 @@ case class InsertIntoParquetTable(
     def writeShard(context: TaskContext, iter: Iterator[Row]): Int = {
       // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
       // around by taking a mod. We expect that no task will be attempted 2 billion times.
-      val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+      val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt
      /* "reduce task" <split #> <attempt # = spark task #> */
-      val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
+      val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId,
         attemptNumber)
       val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
       val format = new AppendingParquetOutputFormat(taskIdOffset)
-- 
GitLab
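
The pattern this patch applies everywhere is the same: drop the closure-based
mapPartitionsWithContext((context, iter) => ...) and call plain mapPartitions
instead, fetching the running task's context from the thread-local accessor
TaskContext.get(), exactly as the PairRDDFunctions hunk above does. A minimal
sketch of user code against the post-cleanup API follows; the object name, app
name, and local[2] master are hypothetical, while TaskContext.get() and the
getStageId/getPartitionId getters are the ones exercised in the diff.

    import org.apache.spark.{SparkConf, SparkContext, TaskContext}

    // Hypothetical driver program illustrating the post-cleanup pattern.
    object TaskContextGetSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setAppName("TaskContextGetSketch").setMaster("local[2]"))
        val rdd = sc.parallelize(1 to 100, numSlices = 4)

        // No (context, iter) closure parameter needed: TaskContext.get() is a
        // thread-local lookup that is valid inside any code run by a task.
        val tagged = rdd.mapPartitions { iter =>
          val context = TaskContext.get()
          iter.map(x => (context.getStageId, context.getPartitionId, x))
        }

        tagged.take(5).foreach(println)
        sc.stop()
      }
    }

Because the context now comes from a thread-local rather than a function
parameter, it no longer has to be threaded through transformation signatures,
which is why mapPartitionsWithContext could be dropped from the proactive
serialization test matrix above.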