diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index f76616a4c451ca3b2b960724efb584dbb3dfcb42..dc48378fdc0d77f3c459458ff97aab4f0893ac17 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -8,10 +8,7 @@ import scala.io.Source
 
 import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
 import spark.broadcast.Broadcast
-import spark.SparkEnv
-import spark.Split
-import spark.RDD
-import spark.OneToOneDependency
+import spark._
 import spark.rdd.PipedRDD
 
 
@@ -34,7 +31,7 @@ private[spark] class PythonRDD[T: ClassManifest](
 
   override val partitioner = if (preservePartitoning) parent.partitioner else None
 
-  override def compute(split: Split): Iterator[Array[Byte]] = {
+  override def compute(split: Split, context: TaskContext): Iterator[Array[Byte]] = {
     val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
 
     val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py"))
@@ -74,7 +71,7 @@ private[spark] class PythonRDD[T: ClassManifest](
           out.println(elem)
         }
         out.flush()
-        for (elem <- parent.iterator(split)) {
+        for (elem <- parent.iterator(split, context)) {
           PythonRDD.writeAsPickle(elem, dOut)
         }
         dOut.flush()
@@ -123,8 +120,8 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
   RDD[(Array[Byte], Array[Byte])](prev.context) {
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split) =
-    prev.iterator(split).grouped(2).map {
+  override def compute(split: Split, context: TaskContext) =
+    prev.iterator(split, context).grouped(2).map {
       case Seq(a, b) => (a, b)
       case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
     }
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 203f7377d2c049182afdc5d114392ab355e9a22c..21dda31c4e435d2af0766d7a5770729374315ab4 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -335,9 +335,10 @@ class RDD(object):
         """
         items = []
         splits = self._jrdd.splits()
+        taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0)
         while len(items) < num and splits:
             split = splits.pop(0)
-            iterator = self._jrdd.iterator(split)
+            iterator = self._jrdd.iterator(split, taskContext)
             items.extend(self._collect_iterator_through_file(iterator))
         return items[:num]
 