Commit ec31e68d authored by root

Fixed PySpark perf regression by not using socket.makefile(), and improved
debuggability by letting "print" statements show up in the executor's stderr.

Conflicts:
	core/src/main/scala/spark/api/python/PythonRDD.scala
parent 3296d132
core/src/main/scala/spark/api/python/PythonRDD.scala

@@ -22,6 +22,8 @@ private[spark] class PythonRDD[T: ClassManifest](
     accumulator: Accumulator[JList[Array[Byte]]])
   extends RDD[Array[Byte]](parent) {
 
+  val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
   def this(parent: RDD[T], command: String, envVars: JMap[String, String],
@@ -45,7 +47,7 @@ private[spark] class PythonRDD[T: ClassManifest](
     new Thread("stdin writer for " + pythonExec) {
       override def run() {
         SparkEnv.set(env)
-        val stream = new BufferedOutputStream(worker.getOutputStream)
+        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
         val dataOut = new DataOutputStream(stream)
         val printOut = new PrintWriter(stream)
         // Partition index
@@ -76,7 +78,7 @@ private[spark] class PythonRDD[T: ClassManifest](
     }.start()
 
     // Return an iterator that read lines from the process's stdout
-    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream))
+    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
     return new Iterator[Array[Byte]] {
       def next(): Array[Byte] = {
         val obj = _nextObj
@@ -276,6 +278,8 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
   extends AccumulatorParam[JList[Array[Byte]]] {
 
   Utils.checkHost(serverHost, "Expected hostname")
 
+  val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
   override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
@@ -289,7 +293,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
       // This happens on the master, where we pass the updates to Python through a socket
       val socket = new Socket(serverHost, serverPort)
       val in = socket.getInputStream
-      val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream))
+      val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
       out.writeInt(val2.size)
       for (array <- val2) {
         out.writeInt(array.length)
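The accumulator hunk above also shows the wire format for accumulator updates: the JVM writes the number of updates with writeInt, then each update as a 4-byte big-endian length followed by the raw bytes. As a point of reference, here is a minimal sketch of what the receiving side of that framing could look like in Python; recv_exactly, read_int, and recv_accumulator_updates are illustrative names, not PySpark's actual code:

    import struct

    def recv_exactly(sock, n):
        # Loop because recv() may return fewer bytes than requested.
        chunks = []
        while n > 0:
            chunk = sock.recv(n)
            if not chunk:
                raise EOFError("socket closed mid-message")
            chunks.append(chunk)
            n -= len(chunk)
        return b"".join(chunks)

    def read_int(sock):
        # java.io.DataOutputStream.writeInt() emits 4 bytes, big-endian.
        return struct.unpack("!i", recv_exactly(sock, 4))[0]

    def recv_accumulator_updates(sock):
        # Mirror of the JVM side above: writeInt(count), then each update
        # as writeInt(length) followed by the payload bytes.
        return [recv_exactly(sock, read_int(sock))
                for _ in range(read_int(sock))]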
core/src/main/scala/spark/api/python/PythonWorkerFactory.scala

@@ -51,7 +51,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
       val workerEnv = pb.environment()
       workerEnv.putAll(envVars)
       daemon = pb.start()
-      daemonPort = new DataInputStream(daemon.getInputStream).readInt()
 
       // Redirect the stderr to ours
       new Thread("stderr reader for " + pythonExec) {
@@ -69,6 +68,25 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
           }
         }
       }.start()
+
+      val in = new DataInputStream(daemon.getInputStream)
+      daemonPort = in.readInt()
+
+      // Redirect further stdout output to our stderr
+      new Thread("stdout reader for " + pythonExec) {
+        override def run() {
+          scala.util.control.Exception.ignoring(classOf[IOException]) {
+            // FIXME HACK: We copy the stream on the level of bytes to
+            // attempt to dodge encoding problems.
+            var buf = new Array[Byte](1024)
+            var len = in.read(buf)
+            while (len != -1) {
+              System.err.write(buf, 0, len)
+              len = in.read(buf)
+            }
+          }
+        }
+      }.start()
     } catch {
       case e => {
         stopDaemon()
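The hunk above changes the daemon startup handshake: instead of reading the port and then abandoning the daemon's stdout, the JVM keeps the stream, reads the port as a single binary int, and forwards everything the daemon prints afterwards to its own stderr, which is what makes stray "print" output visible rather than fatal to the protocol. A rough sketch of the Python side of that handshake; announce_port is a hypothetical stand-in for the write_int call in daemon.py:

    import os
    import struct

    def announce_port(listen_sock):
        # Pack the port big-endian so DataInputStream.readInt() on the
        # JVM side decodes it directly (intended to match write_int).
        _, port = listen_sock.getsockname()
        os.write(1, struct.pack("!i", port))
        # Anything written to stdout after this point is copied to the
        # JVM's stderr by the new "stdout reader" thread, so debugging
        # prints show up instead of corrupting the port read.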
python/pyspark/daemon.py

 import os
+import signal
+import socket
 import sys
+import traceback
 import multiprocessing
 from ctypes import c_bool
 from errno import EINTR, ECHILD
-from socket import socket, AF_INET, SOCK_STREAM, SOMAXCONN
-from signal import signal, SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
+from socket import AF_INET, SOCK_STREAM, SOMAXCONN
+from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
 from pyspark.worker import main as worker_main
 from pyspark.serializers import write_int
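The import changes are a correctness fix, not just style: "from socket import socket" and "from signal import signal" bind the names socket and signal to a class and a function, shadowing the modules themselves, which is fragile once any code needs other attributes of those modules. A standalone illustration of the hazard (not from this commit):

    from signal import signal     # binds the name to the *function*
    try:
        signal.SIGTERM            # AttributeError: the function object
    except AttributeError:        # has no SIGTERM attribute
        pass

    import signal                 # rebinds the name to the *module*
    signal.signal                 # module-qualified access is unambiguous

Hence every bare signal(...) call below becomes signal.signal(...), and the listening socket is created with socket.socket(...).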
@@ -33,11 +36,12 @@ def compute_real_exit_code(exit_code):
 def worker(listen_sock):
     # Redirect stdout to stderr
     os.dup2(2, 1)
+    sys.stdout = sys.stderr  # The sys.stdout object is different from file descriptor 1
 
     # Manager sends SIGHUP to request termination of workers in the pool
     def handle_sighup(*args):
         assert should_exit()
-    signal(SIGHUP, handle_sighup)
+    signal.signal(SIGHUP, handle_sighup)
 
     # Cleanup zombie children
     def handle_sigchld(*args):
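Redirection here happens at two layers because, as the new comment notes, the sys.stdout object is not the same thing as file descriptor 1: os.dup2(2, 1) rewires the descriptor (catching C-level and child-process writes), but sys.stdout keeps its own block buffering, so Python-level prints could still sit in a buffer and be lost at os._exit(). A small sketch of the two layers, assuming a POSIX environment:

    import os
    import sys

    # Layer 1: the kernel descriptor. After this, anything written to
    # fd 1, including by C extensions or forked children, lands on stderr.
    os.dup2(2, 1)

    # Layer 2: the Python file object. Aliasing it to sys.stderr also
    # gets stderr's unbuffered behaviour, so prints appear immediately
    # and survive an abrupt os._exit().
    sys.stdout = sys.stderr

    print("this line goes to stderr, unbuffered")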
@@ -51,7 +55,7 @@ def worker(listen_sock):
                 handle_sigchld()
             elif err.errno != ECHILD:
                 raise
-    signal(SIGCHLD, handle_sigchld)
+    signal.signal(SIGCHLD, handle_sigchld)
 
     # Handle clients
     while not should_exit():
@@ -70,19 +74,22 @@ def worker(listen_sock):
         # never receives SIGCHLD unless a worker crashes.
         if os.fork() == 0:
             # Leave the worker pool
-            signal(SIGHUP, SIG_DFL)
+            signal.signal(SIGHUP, SIG_DFL)
             listen_sock.close()
-            # Handle the client then exit
-            sockfile = sock.makefile()
+            # Read the socket using fdopen instead of socket.makefile() because the latter
+            # seems to be very slow; note that we need to dup() the file descriptor because
+            # otherwise writes also cause a seek that makes us miss data on the read side.
+            infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
+            outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
             exit_code = 0
             try:
-                worker_main(sockfile, sockfile)
+                worker_main(infile, outfile)
             except SystemExit as exc:
                 exit_code = exc.code
             finally:
-                sockfile.close()
+                outfile.flush()
                 sock.close()
             os._exit(compute_real_exit_code(exit_code))
         else:
             sock.close()
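This hunk is the perf fix named in the commit message. In CPython 2, socket.makefile() returns socket._fileobject, a pure-Python buffered wrapper whose per-call overhead is the likely culprit compared with the C stdio file that os.fdopen() returns; and a single stdio stream opened for both reading and writing must flush/seek when switching direction, which on a socket loses data, hence one dup()'d descriptor per direction. A toy sketch of the same pattern on a local socketpair, assuming the Python 2 semantics daemon.py targets:

    import os
    import socket

    def open_socket_files(sock, bufsize=65536):
        # One buffered file object per direction, each on its own dup()'d
        # descriptor, so stdio never switches a single stream between
        # reading and writing.
        infile = os.fdopen(os.dup(sock.fileno()), "a+", bufsize)
        outfile = os.fdopen(os.dup(sock.fileno()), "a+", bufsize)
        return infile, outfile

    a, b = socket.socketpair()
    infile, outfile = open_socket_files(a)
    b.sendall(b"ping\n")
    assert infile.readline() == b"ping\n"
    outfile.write(b"pong\n")
    outfile.flush()                 # as in the finally: block above
    assert b.recv(5) == b"pong\n"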
@@ -92,7 +99,6 @@ def launch_worker(listen_sock):
     try:
         worker(listen_sock)
     except Exception as err:
-        import traceback
         traceback.print_exc()
         os._exit(1)
     else:
@@ -105,7 +111,7 @@ def manager():
     os.setpgid(0, 0)
 
     # Create a listening socket on the AF_INET loopback interface
-    listen_sock = socket(AF_INET, SOCK_STREAM)
+    listen_sock = socket.socket(AF_INET, SOCK_STREAM)
     listen_sock.bind(('127.0.0.1', 0))
     listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN))
     listen_host, listen_port = listen_sock.getsockname()
@@ -121,8 +127,8 @@ def manager():
         exit_flag.value = True
 
     # Gracefully exit on SIGTERM, don't die on SIGHUP
-    signal(SIGTERM, lambda signum, frame: shutdown())
-    signal(SIGHUP, SIG_IGN)
+    signal.signal(SIGTERM, lambda signum, frame: shutdown())
+    signal.signal(SIGHUP, SIG_IGN)
 
     # Cleanup zombie children
     def handle_sigchld(*args):
@@ -133,7 +139,7 @@ def manager():
         except EnvironmentError as err:
             if err.errno not in (ECHILD, EINTR):
                 raise
-    signal(SIGCHLD, handle_sigchld)
+    signal.signal(SIGCHLD, handle_sigchld)
 
     # Initialization complete
     sys.stdout.close()
@@ -148,7 +154,7 @@ def manager():
         shutdown()
         raise
     finally:
-        signal(SIGTERM, SIG_DFL)
+        signal.signal(SIGTERM, SIG_DFL)
         exit_flag.value = True
         # Send SIGHUP to notify workers of shutdown
         os.kill(0, SIGHUP)
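The shutdown broadcast at the end relies on two earlier lines in this same diff: manager() makes itself a process-group leader with os.setpgid(0, 0), and it ignores SIGHUP, so os.kill(0, SIGHUP) signals every pooled worker without killing the manager itself. A stripped-down sketch of that pattern, assuming POSIX (the real workers install handle_sighup rather than the default disposition used here):

    import os
    import signal
    import time

    os.setpgid(0, 0)                              # become group leader, as manager() does
    signal.signal(signal.SIGHUP, signal.SIG_IGN)  # survive our own broadcast

    for _ in range(4):
        if os.fork() == 0:
            # Child: default SIGHUP disposition terminates the process.
            signal.signal(signal.SIGHUP, signal.SIG_DFL)
            while True:
                time.sleep(1)

    time.sleep(0.1)               # let the children start
    os.kill(0, signal.SIGHUP)     # pid 0 == everyone in our process group
    for _ in range(4):
        os.wait()                 # reap the terminated workers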