Skip to content
Snippets Groups Projects
Commit 3776f2f2 authored by Bouke van der Bijl's avatar Bouke van der Bijl Committed by Patrick Wendell
Browse files

Add Python includes to path before depickling broadcast values

This fixes https://issues.apache.org/jira/browse/SPARK-1731 by adding the Python includes to the PYTHONPATH before depickling the broadcast values

@airhorns

Author: Bouke van der Bijl <boukevanderbijl@gmail.com>

Closes #656 from bouk/python-includes-before-broadcast and squashes the following commits:

7b0dfe4 [Bouke van der Bijl] Add Python includes to path before depickling broadcast values
parent c05d11bb
No related branches found
No related tags found
No related merge requests found
......@@ -179,6 +179,11 @@ private[spark] class PythonRDD[T: ClassTag](
dataOut.writeInt(split.index)
// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.length)
for (include <- pythonIncludes) {
PythonRDD.writeUTF(include, dataOut)
}
// Broadcast variables
dataOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
......@@ -186,11 +191,6 @@ private[spark] class PythonRDD[T: ClassTag](
dataOut.writeInt(broadcast.value.length)
dataOut.write(broadcast.value)
}
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.length)
for (include <- pythonIncludes) {
PythonRDD.writeUTF(include, dataOut)
}
dataOut.flush()
// Serialized command:
dataOut.writeInt(command.length)
......
......@@ -56,13 +56,6 @@ def main(infile, outfile):
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
# fetch names and values of broadcast variables
num_broadcast_variables = read_int(infile)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
value = pickleSer._read_with_length(infile)
_broadcastRegistry[bid] = Broadcast(bid, value)
# fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
sys.path.append(spark_files_dir) # *.py files that were added will be copied here
num_python_includes = read_int(infile)
......@@ -70,6 +63,13 @@ def main(infile, outfile):
filename = utf8_deserializer.loads(infile)
sys.path.append(os.path.join(spark_files_dir, filename))
# fetch names and values of broadcast variables
num_broadcast_variables = read_int(infile)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
value = pickleSer._read_with_length(infile)
_broadcastRegistry[bid] = Broadcast(bid, value)
command = pickleSer._read_with_length(infile)
(func, deserializer, serializer) = command
init_time = time.time()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment