Commit 3446d5c8 authored by Patrick Wendell

SPARK-673: Capture and re-throw Python exceptions

This patch alters the Python <-> executor protocol to pass exception data
back when an exception occurs in user Python code.
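
For illustration, here is a minimal sketch of the worker-side framing this
patch extends. It is not the actual PySpark code: write_int and
write_with_length mirror the helpers used in worker.py, the emit wrapper is
hypothetical, and the big-endian struct packing is an assumption chosen to
match what Java's DataInputStream.readInt expects. A positive int prefixes a
serialized record, -1 marks the end of the data section, and -2 now signals
that a length-prefixed traceback follows.

# Sketch only: mimics the framing between the Python worker and the
# JVM-side PythonRDD reader. Helper names follow worker.py; the packing
# format is an assumption (big-endian, as read by DataInputStream.readInt).
import struct
import sys
import traceback

def write_int(value, stream):
    stream.write(struct.pack("!i", value))   # 4-byte big-endian int

def write_with_length(payload, stream):
    write_int(len(payload), stream)          # length prefix ...
    stream.write(payload)                    # ... then the raw bytes

def emit(records, stream):
    try:
        for serialized in records:           # each record: length + bytes
            write_with_length(serialized, stream)
    except Exception:
        write_int(-2, stream)                # sentinel: exception follows
        write_with_length(traceback.format_exc().encode(), stream)
        sys.exit(-1)
    write_int(-1, stream)                    # end of data; accumulators next

On the JVM side, PythonRDD.read() dispatches on that leading int and
re-throws the received traceback as a PythonException when it reads -2, as
the Scala hunk below shows.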
parent 55327a28
@@ -103,21 +103,30 @@ private[spark] class PythonRDD[T: ClassManifest](
   private def read(): Array[Byte] = {
     try {
-      val length = stream.readInt()
-      if (length != -1) {
-        val obj = new Array[Byte](length)
-        stream.readFully(obj)
-        obj
-      } else {
-        // We've finished the data section of the output, but we can still read some
-        // accumulator updates; let's do that, breaking when we get EOFException
-        while (true) {
-          val len2 = stream.readInt()
-          val update = new Array[Byte](len2)
-          stream.readFully(update)
-          accumulator += Collections.singletonList(update)
-        }
-        new Array[Byte](0)
+      stream.readInt() match {
+        case length if length > 0 => {
+          val obj = new Array[Byte](length)
+          stream.readFully(obj)
+          obj
+        }
+        case -2 => {
+          // Signals that an exception has been thrown in python
+          val exLength = stream.readInt()
+          val obj = new Array[Byte](exLength)
+          stream.readFully(obj)
+          throw new PythonException(new String(obj))
+        }
+        case -1 => {
+          // We've finished the data section of the output, but we can still read some
+          // accumulator updates; let's do that, breaking when we get EOFException
+          while (true) {
+            val len2 = stream.readInt()
+            val update = new Array[Byte](len2)
+            stream.readFully(update)
+            accumulator += Collections.singletonList(update)
+          }
+          new Array[Byte](0)
+        }
       }
     } catch {
       case eof: EOFException => {
@@ -140,6 +149,9 @@ private[spark] class PythonRDD[T: ClassManifest](
   val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 }
 
+/** Thrown for exceptions in user Python code. */
+private class PythonException(msg: String) extends Exception(msg)
+
 /**
  * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
  * This is used by PySpark's shuffle operations.
...
@@ -2,6 +2,7 @@
 Worker that receives input from Piped RDD.
 """
 import sys
+import traceback
 from base64 import standard_b64decode
 # CloudPickler needs to be imported so that depicklers are registered using the
 # copy_reg module.
@@ -40,8 +41,13 @@ def main():
     else:
         dumps = dump_pickle
     iterator = read_from_pickle_file(sys.stdin)
-    for obj in func(split_index, iterator):
-        write_with_length(dumps(obj), old_stdout)
+    try:
+        for obj in func(split_index, iterator):
+            write_with_length(dumps(obj), old_stdout)
+    except Exception as e:
+        write_int(-2, old_stdout)
+        write_with_length(traceback.format_exc(), old_stdout)
+        sys.exit(-1)
     # Mark the beginning of the accumulators section of the output
     write_int(-1, old_stdout)
     for aid, accum in _accumulatorRegistry.items():
...