Commit 6550e5e6 authored by Matei Zaharia

Allow PySpark to launch worker.py directly on Windows

parent 3c520fea
core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -17,8 +17,8 @@
 package org.apache.spark.api.python

-import java.io.{File, DataInputStream, IOException}
-import java.net.{Socket, SocketException, InetAddress}
+import java.io.{OutputStreamWriter, File, DataInputStream, IOException}
+import java.net.{ServerSocket, Socket, SocketException, InetAddress}

 import scala.collection.JavaConversions._
@@ -26,11 +26,30 @@ import org.apache.spark._
 private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
     extends Logging {

+  // Because forking processes from Java is expensive, we prefer to launch a single Python daemon
+  // (pyspark/daemon.py) and tell it to fork new workers for our tasks. This daemon currently only
+  // works on UNIX-based systems, because it uses signals for child management; on other systems
+  // we fall back to launching workers (pyspark/worker.py) directly.
+  val useDaemon = !System.getProperty("os.name").startsWith("Windows")
+
   var daemon: Process = null
   val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
   var daemonPort: Int = 0

   def create(): Socket = {
+    if (useDaemon) {
+      createThroughDaemon()
+    } else {
+      createSimpleWorker()
+    }
+  }
+
+  /**
+   * Connect to a worker launched through pyspark/daemon.py, which forks python processes itself
+   * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems.
+   */
+  private def createThroughDaemon(): Socket = {
     synchronized {
       // Start the daemon if it hasn't been started
       startDaemon()
@@ -50,6 +69,78 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
     }
   }

+  /**
+   * Launch a worker by executing worker.py directly and telling it to connect to us.
+   */
+  private def createSimpleWorker(): Socket = {
+    var serverSocket: ServerSocket = null
+    try {
+      serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
+
+      // Create and start the worker
+      val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
+      val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/worker.py"))
+      val workerEnv = pb.environment()
+      workerEnv.putAll(envVars)
+      val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH")
+      workerEnv.put("PYTHONPATH", pythonPath)
+      val worker = pb.start()
+
+      // Redirect the worker's stderr to ours
+      new Thread("stderr reader for " + pythonExec) {
+        setDaemon(true)
+        override def run() {
+          scala.util.control.Exception.ignoring(classOf[IOException]) {
+            // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
+            val in = worker.getErrorStream
+            val buf = new Array[Byte](1024)
+            var len = in.read(buf)
+            while (len != -1) {
+              System.err.write(buf, 0, len)
+              len = in.read(buf)
+            }
+          }
+        }
+      }.start()
+
+      // Redirect worker's stdout to our stderr
+      new Thread("stdout reader for " + pythonExec) {
+        setDaemon(true)
+        override def run() {
+          scala.util.control.Exception.ignoring(classOf[IOException]) {
+            // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
+            val in = worker.getInputStream
+            val buf = new Array[Byte](1024)
+            var len = in.read(buf)
+            while (len != -1) {
+              System.err.write(buf, 0, len)
+              len = in.read(buf)
+            }
+          }
+        }
+      }.start()
+
+      // Tell the worker our port
+      val out = new OutputStreamWriter(worker.getOutputStream)
+      out.write(serverSocket.getLocalPort + "\n")
+      out.flush()
+
+      // Wait for it to connect to our socket
+      serverSocket.setSoTimeout(10000)
+      try {
+        return serverSocket.accept()
+      } catch {
+        case e: Exception =>
+          throw new SparkException("Python worker did not connect back in time", e)
+      }
+    } finally {
+      if (serverSocket != null) {
+        serverSocket.close()
+      }
+    }
+    null
+  }
+
   def stop() {
     stopDaemon()
   }
@@ -73,12 +164,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
       // Redirect the stderr to ours
       new Thread("stderr reader for " + pythonExec) {
         setDaemon(true)
         override def run() {
           scala.util.control.Exception.ignoring(classOf[IOException]) {
-            // FIXME HACK: We copy the stream on the level of bytes to
-            // attempt to dodge encoding problems.
+            // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
             val in = daemon.getErrorStream
-            var buf = new Array[Byte](1024)
+            val buf = new Array[Byte](1024)
             var len = in.read(buf)
             while (len != -1) {
               System.err.write(buf, 0, len)
       // Redirect further stdout output to our stderr
       new Thread("stdout reader for " + pythonExec) {
         setDaemon(true)
         override def run() {
           scala.util.control.Exception.ignoring(classOf[IOException]) {
-            // FIXME HACK: We copy the stream on the level of bytes to
-            // attempt to dodge encoding problems.
-            var buf = new Array[Byte](1024)
+            // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
+            val buf = new Array[Byte](1024)
             var len = in.read(buf)
             while (len != -1) {
               System.err.write(buf, 0, len)
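
The byte-level reader thread above now appears in four places in this file (the new worker's stderr and stdout readers, plus the two existing daemon readers), each carrying the same FIXME about copying raw bytes to sidestep encoding problems. Below is a minimal standalone sketch of that pattern factored into a helper; it is illustrative only, not part of the commit, and the ByteArrayInputStream stands in for a child process's output stream.

import java.io.{ByteArrayInputStream, IOException, InputStream}

object StreamForwardSketch {
  // Copy raw bytes from `in` to System.err on a daemon thread, ignoring IO errors: the same
  // shape as the "stderr reader" / "stdout reader" threads in PythonWorkerFactory.
  def forwardToStderr(in: InputStream, name: String): Unit = {
    new Thread("stream forwarder for " + name) {
      setDaemon(true)
      override def run(): Unit = {
        scala.util.control.Exception.ignoring(classOf[IOException]) {
          // Bytes, not characters, so we never have to guess the child's output encoding.
          val buf = new Array[Byte](1024)
          var len = in.read(buf)
          while (len != -1) {
            System.err.write(buf, 0, len)
            len = in.read(buf)
          }
        }
      }
    }.start()
  }

  def main(args: Array[String]): Unit = {
    // Stand-in for a child process's stderr; in PythonWorkerFactory this would be worker.getErrorStream.
    val fakeStderr = new ByteArrayInputStream("child wrote this to stderr\n".getBytes("UTF-8"))
    forwardToStderr(fakeStderr, "demo")
    Thread.sleep(200) // give the daemon thread a moment to drain the stream before the JVM exits
  }
}

In the commit each reader is an inline anonymous Thread; a helper like this would be one way to remove the duplication.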
python/pyspark/worker.py
@@ -21,6 +21,7 @@ Worker that receives input from Piped RDD.
 import os
 import sys
 import time
+import socket
 import traceback
 from base64 import standard_b64decode
 # CloudPickler needs to be imported so that depicklers are registered using the
@@ -94,7 +95,9 @@ def main(infile, outfile):
 if __name__ == '__main__':
-    # Redirect stdout to stderr so that users must return values from functions.
-    old_stdout = os.fdopen(os.dup(1), 'w')
-    os.dup2(2, 1)
-    main(sys.stdin, old_stdout)
+    # Read a local port to connect to from stdin
+    java_port = int(sys.stdin.readline())
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.connect(("127.0.0.1", java_port))
+    sock_file = sock.makefile("a+", 65536)
+    main(sock_file, sock_file)
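
Taken together, the two halves implement a simple connect-back handshake: the JVM listens on an ephemeral loopback port, launches worker.py with the PySpark directory on PYTHONPATH, writes the port number followed by a newline to the worker's stdin, and waits up to ten seconds for the worker to connect. The sketch below exercises the same handshake outside Spark; it is illustrative only and not part of the commit, it assumes a python3 executable on PATH, and the inline child script is only a stand-in for pyspark/worker.py.

import java.io.{BufferedReader, InputStreamReader, OutputStreamWriter}
import java.net.{InetAddress, ServerSocket}

object ConnectBackDemo {
  // Stand-in for worker.py: read a port from stdin, connect back, send one line over the socket.
  val childScript: String =
    """import socket, sys
      |port = int(sys.stdin.readline())
      |sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
      |sock.connect(("127.0.0.1", port))
      |sock.sendall(b"hello from the child\n")
      |""".stripMargin

  def main(args: Array[String]): Unit = {
    // Listen on an ephemeral loopback port with a backlog of 1, as createSimpleWorker does.
    val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array[Byte](127, 0, 0, 1)))
    // "python3" is an assumption about the local environment.
    val child = new ProcessBuilder("python3", "-c", childScript).start()

    // Tell the child which port to connect back to.
    val out = new OutputStreamWriter(child.getOutputStream)
    out.write(serverSocket.getLocalPort + "\n")
    out.flush()

    // Wait up to 10 seconds for the child to connect back, then read its greeting.
    serverSocket.setSoTimeout(10000)
    val conn = serverSocket.accept()
    val reader = new BufferedReader(new InputStreamReader(conn.getInputStream))
    println("child said: " + reader.readLine())

    conn.close()
    serverSocket.close()
    child.destroy()
  }
}

Having the worker dial back to the JVM keeps the port handoff to a single line of text on stdin; presumably this avoids any coordination over which port the Python side would otherwise have to bind and advertise.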