Skip to content
Snippets Groups Projects
Commit 59195c68 authored by Josh Rosen's avatar Josh Rosen
Browse files

Update PySpark for compatibility with TaskContext.

parent c5cee53f
No related branches found
No related tags found
No related merge requests found
......@@ -8,10 +8,7 @@ import scala.io.Source
import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import spark.broadcast.Broadcast
import spark.SparkEnv
import spark.Split
import spark.RDD
import spark.OneToOneDependency
import spark._
import spark.rdd.PipedRDD
......@@ -34,7 +31,7 @@ private[spark] class PythonRDD[T: ClassManifest](
override val partitioner = if (preservePartitoning) parent.partitioner else None
override def compute(split: Split): Iterator[Array[Byte]] = {
override def compute(split: Split, context: TaskContext): Iterator[Array[Byte]] = {
val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py"))
......@@ -74,7 +71,7 @@ private[spark] class PythonRDD[T: ClassManifest](
out.println(elem)
}
out.flush()
for (elem <- parent.iterator(split)) {
for (elem <- parent.iterator(split, context)) {
PythonRDD.writeAsPickle(elem, dOut)
}
dOut.flush()
......@@ -123,8 +120,8 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
RDD[(Array[Byte], Array[Byte])](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) =
prev.iterator(split).grouped(2).map {
override def compute(split: Split, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (a, b)
case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
}
......
......@@ -335,9 +335,10 @@ class RDD(object):
"""
items = []
splits = self._jrdd.splits()
taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0)
while len(items) < num and splits:
split = splits.pop(0)
iterator = self._jrdd.iterator(split)
iterator = self._jrdd.iterator(split, taskContext)
items.extend(self._collect_iterator_through_file(iterator))
return items[:num]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment