diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 08e3f670f57f6ea63feeb59b2b2a8a397acda394..67d45723badd8b4327a558664ffa37abdd5cc061 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -17,8 +17,8 @@
 
 package org.apache.spark.api.python
 
-import java.io.{File, DataInputStream, IOException}
-import java.net.{Socket, SocketException, InetAddress}
+import java.io.{OutputStreamWriter, File, DataInputStream, IOException}
+import java.net.{ServerSocket, Socket, SocketException, InetAddress}
 
 import scala.collection.JavaConversions._
 
@@ -26,11 +26,30 @@ import org.apache.spark._
 
 private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
     extends Logging {
+
+  // Because forking processes from Java is expensive, we prefer to launch a single Python daemon
+  // (pyspark/daemon.py) and tell it to fork new workers for our tasks. This daemon currently
+  // only works on UNIX-based systems now because it uses signals for child management, so we can
+  // also fall back to launching workers (pyspark/worker.py) directly.
+  val useDaemon = !System.getProperty("os.name").startsWith("Windows")
+
   var daemon: Process = null
   val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
   var daemonPort: Int = 0
 
   def create(): Socket = {
+    if (useDaemon) {
+      createThroughDaemon()
+    } else {
+      createSimpleWorker()
+    }
+  }
+
+  /**
+   * Connect to a worker launched through pyspark/daemon.py, which forks python processes itself
+   * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems.
+   */
+  private def createThroughDaemon(): Socket = {
     synchronized {
       // Start the daemon if it hasn't been started
       startDaemon()
@@ -50,6 +69,78 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
     }
   }
 
+  /**
+   * Launch a worker by executing worker.py directly and telling it to connect to us.
+   */
+  private def createSimpleWorker(): Socket = {
+    var serverSocket: ServerSocket = null
+    try {
+      serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
+
+      // Create and start the worker
+      val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
+      val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/worker.py"))
+      val workerEnv = pb.environment()
+      workerEnv.putAll(envVars)
+      val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH")
+      workerEnv.put("PYTHONPATH", pythonPath)
+      val worker = pb.start()
+
+      // Redirect the worker's stderr to ours
+      new Thread("stderr reader for " + pythonExec) {
+        setDaemon(true)
+        override def run() {
+          scala.util.control.Exception.ignoring(classOf[IOException]) {
+            // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
+            val in = worker.getErrorStream
+            val buf = new Array[Byte](1024)
+            var len = in.read(buf)
+            while (len != -1) {
+              System.err.write(buf, 0, len)
+              len = in.read(buf)
+            }
+          }
+        }
+      }.start()
+
+      // Redirect worker's stdout to our stderr
+      new Thread("stdout reader for " + pythonExec) {
+        setDaemon(true)
+        override def run() {
+          scala.util.control.Exception.ignoring(classOf[IOException]) {
+            // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
+            val in = worker.getInputStream
+            val buf = new Array[Byte](1024)
+            var len = in.read(buf)
+            while (len != -1) {
+              System.err.write(buf, 0, len)
+              len = in.read(buf)
+            }
+          }
+        }
+      }.start()
+
+      // Tell the worker our port
+      val out = new OutputStreamWriter(worker.getOutputStream)
+      out.write(serverSocket.getLocalPort + "\n")
+      out.flush()
+
+      // Wait for it to connect to our socket
+      serverSocket.setSoTimeout(10000)
+      try {
+        return serverSocket.accept()
+      } catch {
+        case e: Exception =>
+          throw new SparkException("Python worker did not connect back in time", e)
+      }
+    } finally {
+      if (serverSocket != null) {
+        serverSocket.close()
+      }
+    }
+    null
+  }
+
   def stop() {
     stopDaemon()
   }
@@ -73,12 +164,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
         }
         // Redirect the stderr to ours
         new Thread("stderr reader for " + pythonExec) {
+          setDaemon(true)
           override def run() {
             scala.util.control.Exception.ignoring(classOf[IOException]) {
-              // FIXME HACK: We copy the stream on the level of bytes to
-              // attempt to dodge encoding problems.
+              // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
               val in = daemon.getErrorStream
-              var buf = new Array[Byte](1024)
+              val buf = new Array[Byte](1024)
               var len = in.read(buf)
               while (len != -1) {
                 System.err.write(buf, 0, len)
@@ -93,11 +184,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
 
         // Redirect further stdout output to our stderr
         new Thread("stdout reader for " + pythonExec) {
+          setDaemon(true)
           override def run() {
             scala.util.control.Exception.ignoring(classOf[IOException]) {
-              // FIXME HACK: We copy the stream on the level of bytes to
-              // attempt to dodge encoding problems.
-              var buf = new Array[Byte](1024)
+              // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
+              val buf = new Array[Byte](1024)
               var len = in.read(buf)
               while (len != -1) {
                 System.err.write(buf, 0, len)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 695f6dfb8444c1ef72f1649438854d402c5e0011..d63c2aaef772de62eef3bf913ad4a4859cf30512 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -21,6 +21,7 @@ Worker that receives input from Piped RDD.
 import os
 import sys
 import time
+import socket
 import traceback
 from base64 import standard_b64decode
 # CloudPickler needs to be imported so that depicklers are registered using the
@@ -94,7 +95,9 @@ def main(infile, outfile):
 
 
 if __name__ == '__main__':
-    # Redirect stdout to stderr so that users must return values from functions.
-    old_stdout = os.fdopen(os.dup(1), 'w')
-    os.dup2(2, 1)
-    main(sys.stdin, old_stdout)
+    # Read a local port to connect to from stdin
+    java_port = int(sys.stdin.readline())
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.connect(("127.0.0.1", java_port))
+    sock_file = sock.makefile("a+", 65536)
+    main(sock_file, sock_file)
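Reviewer note (not part of the patch): the connect-back handshake added above can be exercised without a JVM by playing the role of createSimpleWorker() from a small Python script. The sketch below is a hypothetical harness under stated assumptions: SPARK_HOME points at a checkout containing python/pyspark/worker.py, and `python` on the PATH is the interpreter under test; the file name harness_worker_handshake.py is illustrative only.

# harness_worker_handshake.py -- hypothetical test harness, not part of Spark.
# Stands in for the JVM side of createSimpleWorker(): listen on an ephemeral
# loopback port, start worker.py, hand it the port on stdin, and wait for the
# worker to connect back (the Scala code above allows 10 seconds for this).
import os
import socket
import subprocess

spark_home = os.environ["SPARK_HOME"]  # assumed to point at a Spark checkout

# Listen on an ephemeral loopback port, like `new ServerSocket(0, 1, ...)` above.
listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listener.bind(("127.0.0.1", 0))
listener.listen(1)
port = listener.getsockname()[1]

# Launch worker.py with python/ on its PYTHONPATH, mirroring the ProcessBuilder setup.
env = dict(os.environ)
env["PYTHONPATH"] = os.path.join(spark_home, "python")
worker = subprocess.Popen(
    ["python", os.path.join(spark_home, "python", "pyspark", "worker.py")],
    stdin=subprocess.PIPE, env=env)

# Tell the worker our port on stdin; worker.py reads it with sys.stdin.readline().
worker.stdin.write(b"%d\n" % port)
worker.stdin.flush()

# The worker should connect back promptly; give it the same 10-second budget.
listener.settimeout(10)
conn, addr = listener.accept()
print("worker connected back from %s:%d" % addr)

# We never send task data, so the worker is still blocked in main(); clean up.
conn.close()
listener.close()
worker.terminate()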