Commit c79a6078 authored by Jey Kottalam

Prefork Python worker processes

parent 40afe0d2
SparkEnv.scala

 package spark

+import collection.mutable
+import serializer.Serializer
+
 import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
 import akka.remote.RemoteActorRefProvider
@@ -9,6 +12,7 @@ import spark.storage.BlockManagerMaster
 import spark.network.ConnectionManager
 import spark.serializer.{Serializer, SerializerManager}
 import spark.util.AkkaUtils
+import spark.api.python.PythonWorker

 /**
@@ -37,6 +41,8 @@ class SparkEnv (
     // If executorId is NOT found, return defaultHostPort
     var executorIdToHostPort: Option[(String, String) => String]) {

+  private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorker]()
+
   def stop() {
     httpFileServer.stop()
     mapOutputTracker.stop()
@@ -50,6 +56,11 @@ class SparkEnv (
     actorSystem.awaitTermination()
   }

+  def getPythonWorker(pythonExec: String, envVars: Map[String, String]): PythonWorker = {
+    synchronized {
+      pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorker(pythonExec, envVars))
+    }
+  }
+
   def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = {
     val env = SparkEnv.get
PythonRDD.scala

@@ -2,10 +2,9 @@ package spark.api.python

 import java.io._
 import java.net._
-import java.util.{List => JList, ArrayList => JArrayList, Collections}
+import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}

 import scala.collection.JavaConversions._
-import scala.io.Source

 import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
 import spark.broadcast.Broadcast
@@ -16,7 +15,7 @@ import spark.rdd.PipedRDD

 private[spark] class PythonRDD[T: ClassManifest](
     parent: RDD[T],
     command: Seq[String],
-    envVars: java.util.Map[String, String],
+    envVars: JMap[String, String],
     preservePartitoning: Boolean,
     pythonExec: String,
     broadcastVars: JList[Broadcast[Array[Byte]]],
@@ -25,7 +24,7 @@ private[spark] class PythonRDD[T: ClassManifest](

   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
-  def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
+  def this(parent: RDD[T], command: String, envVars: JMap[String, String],
       preservePartitoning: Boolean, pythonExec: String,
       broadcastVars: JList[Broadcast[Array[Byte]]],
       accumulator: Accumulator[JList[Array[Byte]]]) =
@@ -36,36 +35,18 @@ private[spark] class PythonRDD[T: ClassManifest](

   override val partitioner = if (preservePartitoning) parent.partitioner else None

-  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
-
-    val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
-    // Add the environmental variables to the process.
-    val currentEnvVars = pb.environment()
-    for ((variable, value) <- envVars) {
-      currentEnvVars.put(variable, value)
-    }
-
+  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
     val startTime = System.currentTimeMillis
-    val proc = pb.start()
+    val worker = SparkEnv.get.getPythonWorker(pythonExec, envVars.toMap).create
     val env = SparkEnv.get

-    // Start a thread to print the process's stderr to ours
-    new Thread("stderr reader for " + pythonExec) {
-      override def run() {
-        for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
-          System.err.println(line)
-        }
-      }
-    }.start()
-
     // Start a thread to feed the process input from our parent's iterator
     new Thread("stdin writer for " + pythonExec) {
       override def run() {
         SparkEnv.set(env)
-        val out = new PrintWriter(proc.getOutputStream)
-        val dOut = new DataOutputStream(proc.getOutputStream)
+        val out = new PrintWriter(worker.getOutputStream)
+        val dOut = new DataOutputStream(worker.getOutputStream)
         // Partition index
         dOut.writeInt(split.index)
         // sparkFilesDir
@@ -89,16 +70,21 @@ private[spark] class PythonRDD[T: ClassManifest](
         }
         dOut.flush()
         out.flush()
-        proc.getOutputStream.close()
+        worker.shutdownOutput()
       }
     }.start()

     // Return an iterator that read lines from the process's stdout
-    val stream = new DataInputStream(proc.getInputStream)
+    val stream = new DataInputStream(worker.getInputStream)
     return new Iterator[Array[Byte]] {
       def next(): Array[Byte] = {
         val obj = _nextObj
-        _nextObj = read()
+        if (hasNext) {
+          // FIXME: can deadlock if worker is waiting for us to
+          // respond to current message (currently irrelevant because
+          // output is shutdown before we read any input)
+          _nextObj = read()
+        }
         obj
       }
@@ -110,7 +96,7 @@ private[spark] class PythonRDD[T: ClassManifest](
            stream.readFully(obj)
            obj
          case -3 =>
-            // Timing data from child
+            // Timing data from worker
            val bootTime = stream.readLong()
            val initTime = stream.readLong()
            val finishTime = stream.readLong()
@@ -127,23 +113,21 @@ private[spark] class PythonRDD[T: ClassManifest](
            stream.readFully(obj)
            throw new PythonException(new String(obj))
          case -1 =>
-            // We've finished the data section of the output, but we can still read some
-            // accumulator updates; let's do that, breaking when we get EOFException
-            while (true) {
-              val len2 = stream.readInt()
+            // We've finished the data section of the output, but we can still
+            // read some accumulator updates; let's do that, breaking when we
+            // get a negative length record.
+            var len2 = stream.readInt
+            while (len2 >= 0) {
              val update = new Array[Byte](len2)
              stream.readFully(update)
              accumulator += Collections.singletonList(update)
+              len2 = stream.readInt
            }
            new Array[Byte](0)
        }
      } catch {
        case eof: EOFException => {
-          val exitStatus = proc.waitFor()
-          if (exitStatus != 0) {
-            throw new Exception("Subprocess exited with status " + exitStatus)
-          }
-          new Array[Byte](0)
+          throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
        }
        case e => throw e
      }
@@ -171,7 +155,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
   override def compute(split: Partition, context: TaskContext) =
     prev.iterator(split, context).grouped(2).map {
       case Seq(a, b) => (a, b)
-      case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
+      case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
     }
   val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
 }
@@ -227,7 +211,7 @@ private[spark] object PythonRDD {
       dOut.write(s)
       dOut.writeByte(Pickle.STOP)
     } else {
-      throw new Exception("Unexpected RDD type")
+      throw new SparkException("Unexpected RDD type")
     }
   }
PythonWorker.scala (new file)

package spark.api.python

import java.io.DataInputStream
import java.net.{Socket, SocketException, InetAddress}

import scala.collection.JavaConversions._

import spark._

private[spark] class PythonWorker(pythonExec: String, envVars: Map[String, String])
    extends Logging {
  var daemon: Process = null
  val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1).map(_.toByte))
  var daemonPort: Int = 0

  def create(): Socket = {
    synchronized {
      // Start the daemon if it hasn't been started
      startDaemon

      // Attempt to connect, restart and retry once if it fails
      try {
        new Socket(daemonHost, daemonPort)
      } catch {
        case exc: SocketException => {
          logWarning("Python daemon unexpectedly quit, attempting to restart")
          stopDaemon
          startDaemon
          new Socket(daemonHost, daemonPort)
        }
        case e => throw e
      }
    }
  }

  private def startDaemon() {
    synchronized {
      // Is it already running?
      if (daemon != null) {
        return
      }

      try {
        // Create and start the daemon
        val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
        val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py"))
        val workerEnv = pb.environment()
        workerEnv.putAll(envVars)
        daemon = pb.start()
        daemonPort = new DataInputStream(daemon.getInputStream).readInt

        // Redirect the stderr to ours
        new Thread("stderr reader for " + pythonExec) {
          override def run() {
            // FIXME HACK: We copy the stream on the level of bytes to
            // attempt to dodge encoding problems.
            val in = daemon.getErrorStream
            var buf = new Array[Byte](1024)
            var len = in.read(buf)
            while (len != -1) {
              System.err.write(buf, 0, len)
              len = in.read(buf)
            }
          }
        }.start()
      } catch {
        case e => {
          stopDaemon
          throw e
        }
      }

      // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly
      // detect our disappearance.
    }
  }

  private def stopDaemon() {
    synchronized {
      // Request shutdown of existing daemon by sending SIGTERM
      if (daemon != null) {
        daemon.destroy
      }

      daemon = null
      daemonPort = 0
    }
  }
}
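For reference, the handshake PythonWorker performs is small: launch pythonExec on pyspark/daemon.py, read the daemon's listen port as a 4-byte big-endian int from its stdout (DataInputStream.readInt on the JVM side), then open one loopback TCP connection per task via create(). The following is a standalone sketch of that handshake, not part of this commit; it assumes SPARK_HOME is set, that pyspark is importable by the daemon, and uses a plain "python" executable as a stand-in for pythonExec.

# Hypothetical standalone exercise of the daemon handshake described above.
import os
import socket
import struct
import subprocess

daemon = subprocess.Popen(
    ["python", os.path.join(os.environ["SPARK_HOME"], "python/pyspark/daemon.py")],
    stdin=subprocess.PIPE, stdout=subprocess.PIPE)

# daemon.py announces its listen port as a 4-byte big-endian int on stdout,
# which is what DataInputStream.readInt expects on the Scala side.
port = struct.unpack(">i", daemon.stdout.read(4))[0]

# Each PythonWorker.create() call is just a new loopback TCP connection;
# the daemon forks a pooled worker to serve it.
conn = socket.create_connection(("127.0.0.1", port))
conn.close()

# Closing the daemon's stdin is the shutdown signal (see the stdin.read() loop
# in daemon.py), which is why the JVM deliberately keeps it open while running.
daemon.stdin.close()
daemon.wait()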
python/pyspark/daemon.py (new file)

import os
import sys
import multiprocessing
from errno import EINTR, ECHILD
from socket import socket, AF_INET, SOCK_STREAM, SOMAXCONN
from signal import signal, SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
from pyspark.worker import main as worker_main
from pyspark.serializers import write_int

try:
    POOLSIZE = multiprocessing.cpu_count()
except NotImplementedError:
    POOLSIZE = 4

should_exit = False


def worker(listen_sock):
    # Redirect stdout to stderr
    os.dup2(2, 1)

    # Manager sends SIGHUP to request termination of workers in the pool
    def handle_sighup(signum, frame):
        global should_exit
        should_exit = True
    signal(SIGHUP, handle_sighup)

    while not should_exit:
        # Wait until a client arrives or we have to exit
        sock = None
        while not should_exit and sock is None:
            try:
                sock, addr = listen_sock.accept()
            except EnvironmentError as err:
                if err.errno != EINTR:
                    raise

        if sock is not None:
            # Fork a child to handle the client
            if os.fork() == 0:
                # Leave the worker pool
                signal(SIGHUP, SIG_DFL)
                listen_sock.close()
                # Handle the client then exit
                sockfile = sock.makefile()
                worker_main(sockfile, sockfile)
                sockfile.close()
                sock.close()
                os._exit(0)
            else:
                sock.close()

    assert should_exit
    os._exit(0)


def manager():
    # Create a new process group to corral our children
    os.setpgid(0, 0)

    # Create a listening socket on the AF_INET loopback interface
    listen_sock = socket(AF_INET, SOCK_STREAM)
    listen_sock.bind(('127.0.0.1', 0))
    listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN))
    listen_host, listen_port = listen_sock.getsockname()
    write_int(listen_port, sys.stdout)

    # Launch initial worker pool
    for idx in range(POOLSIZE):
        if os.fork() == 0:
            worker(listen_sock)
            raise RuntimeError("worker() unexpectedly returned")
    listen_sock.close()

    def shutdown():
        global should_exit
        os.kill(0, SIGHUP)
        should_exit = True

    # Gracefully exit on SIGTERM, don't die on SIGHUP
    signal(SIGTERM, lambda signum, frame: shutdown())
    signal(SIGHUP, SIG_IGN)

    # Cleanup zombie children
    def handle_sigchld(signum, frame):
        try:
            pid, status = os.waitpid(0, os.WNOHANG)
            if (pid, status) != (0, 0) and not should_exit:
                raise RuntimeError("pool member crashed: %s, %s" % (pid, status))
        except EnvironmentError as err:
            if err.errno not in (ECHILD, EINTR):
                raise
    signal(SIGCHLD, handle_sigchld)

    # Initialization complete
    sys.stdout.close()
    while not should_exit:
        try:
            # Spark tells us to exit by closing stdin
            if sys.stdin.read() == '':
                shutdown()
        except EnvironmentError as err:
            if err.errno != EINTR:
                shutdown()
                raise


if __name__ == '__main__':
    manager()
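The daemon's lifecycle is tied to the JVM through its stdin: PythonWorker deliberately keeps daemon.getOutputStream open, and the manager loop above treats EOF on stdin as the signal that Spark has gone away. A minimal sketch of that stdin-EOF pattern in isolation, not taken from this commit:

# Hypothetical illustration of the "parent death via stdin EOF" trick the daemon uses:
# the parent never writes to the child's stdin, and the child treats EOF as shutdown.
import subprocess
import sys

child_src = (
    "import sys\n"
    "# Blocks until the parent closes our stdin or exits.\n"
    "if sys.stdin.read() == '':\n"
    "    sys.stderr.write('parent gone, shutting down\\n')\n"
)

child = subprocess.Popen([sys.executable, "-c", child_src], stdin=subprocess.PIPE)
child.stdin.close()   # Stand-in for the parent JVM exiting.
child.wait()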
""" """
Worker that receives input from Piped RDD. Worker that receives input from Piped RDD.
""" """
import time
preboot_time = time.time()
import os import os
import sys import sys
import time
import traceback import traceback
from base64 import standard_b64decode from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the # CloudPickler needs to be imported so that depicklers are registered using the
...@@ -17,57 +16,55 @@ from pyspark.serializers import write_with_length, read_with_length, write_int, ...@@ -17,57 +16,55 @@ from pyspark.serializers import write_with_length, read_with_length, write_int,
read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
# Redirect stdout to stderr so that users must return values from functions. def load_obj(infile):
old_stdout = os.fdopen(os.dup(1), 'w') return load_pickle(standard_b64decode(infile.readline().strip()))
os.dup2(2, 1)
def load_obj():
return load_pickle(standard_b64decode(sys.stdin.readline().strip()))
def report_times(preboot, boot, init, finish): def report_times(outfile, boot, init, finish):
write_int(-3, old_stdout) write_int(-3, outfile)
write_long(1000 * preboot, old_stdout) write_long(1000 * boot, outfile)
write_long(1000 * boot, old_stdout) write_long(1000 * init, outfile)
write_long(1000 * init, old_stdout) write_long(1000 * finish, outfile)
write_long(1000 * finish, old_stdout)
def main(): def main(infile, outfile):
boot_time = time.time() boot_time = time.time()
split_index = read_int(sys.stdin) split_index = read_int(infile)
spark_files_dir = load_pickle(read_with_length(sys.stdin)) spark_files_dir = load_pickle(read_with_length(infile))
SparkFiles._root_directory = spark_files_dir SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True SparkFiles._is_running_on_worker = True
sys.path.append(spark_files_dir) sys.path.append(spark_files_dir)
num_broadcast_variables = read_int(sys.stdin) num_broadcast_variables = read_int(infile)
for _ in range(num_broadcast_variables): for _ in range(num_broadcast_variables):
bid = read_long(sys.stdin) bid = read_long(infile)
value = read_with_length(sys.stdin) value = read_with_length(infile)
_broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
func = load_obj() func = load_obj(infile)
bypassSerializer = load_obj() bypassSerializer = load_obj(infile)
if bypassSerializer: if bypassSerializer:
dumps = lambda x: x dumps = lambda x: x
else: else:
dumps = dump_pickle dumps = dump_pickle
init_time = time.time() init_time = time.time()
iterator = read_from_pickle_file(sys.stdin) iterator = read_from_pickle_file(infile)
try: try:
for obj in func(split_index, iterator): for obj in func(split_index, iterator):
write_with_length(dumps(obj), old_stdout) write_with_length(dumps(obj), outfile)
except Exception as e: except Exception as e:
write_int(-2, old_stdout) write_int(-2, outfile)
write_with_length(traceback.format_exc(), old_stdout) write_with_length(traceback.format_exc(), outfile)
sys.exit(-1) raise
finish_time = time.time() finish_time = time.time()
report_times(preboot_time, boot_time, init_time, finish_time) report_times(outfile, boot_time, init_time, finish_time)
# Mark the beginning of the accumulators section of the output # Mark the beginning of the accumulators section of the output
write_int(-1, old_stdout) write_int(-1, outfile)
for aid, accum in _accumulatorRegistry.items(): for aid, accum in _accumulatorRegistry.items():
write_with_length(dump_pickle((aid, accum._value)), old_stdout) write_with_length(dump_pickle((aid, accum._value)), outfile)
write_int(-1, outfile)
if __name__ == '__main__': if __name__ == '__main__':
main() # Redirect stdout to stderr so that users must return values from functions.
old_stdout = os.fdopen(os.dup(1), 'w')
os.dup2(2, 1)
main(sys.stdin, old_stdout)
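Taken together with the reader in PythonRDD, the stream worker_main writes is length-prefixed: a positive int length followed by that many record bytes, -3 followed by three longs of timing data, -2 followed by a length-prefixed traceback, and -1 marking the end of the data section, after which length-prefixed accumulator updates follow until a negative length (the new trailing write_int(-1, outfile)). The sketch below is not part of the commit; it is a Python mirror of PythonRDD's case analysis, assuming write_int/write_long emit 4-byte and 8-byte big-endian values so they match DataInputStream on the JVM.

# Hypothetical reader for the worker's output framing, for illustration only.
import struct

def read_int(stream):
    return struct.unpack(">i", stream.read(4))[0]

def read_long(stream):
    return struct.unpack(">q", stream.read(8))[0]

def read_results(stream):
    """Collect output records and accumulator updates from one task."""
    records, updates = [], []
    while True:
        length = read_int(stream)
        if length >= 0:
            records.append(stream.read(length))      # one serialized output record
        elif length == -3:
            boot = read_long(stream)                  # timing data from the worker
            init = read_long(stream)
            finish = read_long(stream)
        elif length == -2:
            raise RuntimeError(stream.read(read_int(stream)))  # formatted traceback
        elif length == -1:
            # End of the data section; accumulator updates follow, terminated by a
            # negative length record (the trailing write_int(-1, outfile) above).
            length = read_int(stream)
            while length >= 0:
                updates.append(stream.read(length))
                length = read_int(stream)
            return records, updates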