Commit 1381fc72 authored by Josh Rosen

Switch from MUTF8 to UTF8 in PySpark serializers.

This fixes SPARK-1043, a bug introduced in 0.9.0
where PySpark couldn't serialize strings > 64kB.

This fix was written by @tyro89 and @bouk in #512.
This commit squashes and rebases their pull request
in order to fix some merge conflicts.
parent 84670f27
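The root cause is the writer-side framing: java.io.DataOutputStream.writeUTF emits modified UTF-8 with an unsigned 16-bit length prefix, so it throws UTFDataFormatException as soon as a string's encoded form exceeds 65535 bytes. The commit replaces it with an explicit 4-byte length followed by raw UTF-8 bytes, which the new Python UTF8Deserializer reads back with read_int. Below is a minimal standalone Scala sketch of the two behaviours; the WriteUTFSketch object and its main method are illustrative only, while the writeUTF helper matches the one added in this commit.

import java.io.{ByteArrayOutputStream, DataOutputStream, UTFDataFormatException}

object WriteUTFSketch {

  // Framing added in this commit: a 4-byte big-endian length followed by
  // the string's raw UTF-8 bytes, with no 64 kB ceiling.
  def writeUTF(str: String, dataOut: DataOutputStream) {
    val bytes = str.getBytes("UTF-8")
    dataOut.writeInt(bytes.length)
    dataOut.write(bytes)
  }

  def main(args: Array[String]) {
    val big = "a" * 100000
    val out = new DataOutputStream(new ByteArrayOutputStream)

    // Old path: DataOutputStream.writeUTF uses an unsigned 16-bit length
    // prefix, so encoded strings longer than 65535 bytes are rejected.
    try {
      out.writeUTF(big)
    } catch {
      case e: UTFDataFormatException => println("writeUTF failed: " + e.getMessage)
    }

    // New path: explicit int-length framing handles the same string fine.
    writeUTF(big, out)
    println("writeUTF helper wrote " + big.getBytes("UTF-8").length + " bytes")
  }
}

On the Python side, UTF8Deserializer.loads mirrors this framing by reading a 4-byte length with read_int and decoding that many bytes as UTF-8, as shown in the serializers.py hunk below.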
PythonRDD.scala
@@ -64,7 +64,7 @@ private[spark] class PythonRDD[T: ClassTag](
// Partition index
dataOut.writeInt(split.index)
// sparkFilesDir
- dataOut.writeUTF(SparkFiles.getRootDirectory)
+ PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
// Broadcast variables
dataOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
@@ -74,7 +74,9 @@ private[spark] class PythonRDD[T: ClassTag](
}
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.length)
- pythonIncludes.foreach(dataOut.writeUTF)
+ for (include <- pythonIncludes) {
+   PythonRDD.writeUTF(include, dataOut)
+ }
dataOut.flush()
// Serialized command:
dataOut.writeInt(command.length)
@@ -228,7 +230,7 @@ private[spark] object PythonRDD {
}
case string: String =>
newIter.asInstanceOf[Iterator[String]].foreach { str =>
- dataOut.writeUTF(str)
+ writeUTF(str, dataOut)
}
case pair: Tuple2[_, _] =>
pair._1 match {
@@ -241,8 +243,8 @@ private[spark] object PythonRDD {
}
case stringPair: String =>
newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
- dataOut.writeUTF(pair._1)
- dataOut.writeUTF(pair._2)
+ writeUTF(pair._1, dataOut)
+ writeUTF(pair._2, dataOut)
}
case other =>
throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
@@ -253,6 +255,12 @@
}
}
+ def writeUTF(str: String, dataOut: DataOutputStream) {
+   val bytes = str.getBytes("UTF-8")
+   dataOut.writeInt(bytes.length)
+   dataOut.write(bytes)
+ }
def writeToFile[T](items: java.util.Iterator[T], filename: String) {
import scala.collection.JavaConverters._
writeToFile(items.asScala, filename)
PythonRDDSuite.scala (new file)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.python

import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers

import org.apache.spark.api.python.PythonRDD
import java.io.{ByteArrayOutputStream, DataOutputStream}

class PythonRDDSuite extends FunSuite {

  test("Writing large strings to the worker") {
    val input: List[String] = List("a" * 100000)
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    PythonRDD.writeIteratorToStream(input.iterator, buffer)
  }
}
pyspark/context.py
@@ -27,7 +27,7 @@ from pyspark.broadcast import Broadcast
from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
- from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer
+ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
@@ -234,7 +234,7 @@ class SparkContext(object):
"""
minSplits = minSplits or min(self.defaultParallelism, 2)
return RDD(self._jsc.textFile(name, minSplits), self,
- MUTF8Deserializer())
+ UTF8Deserializer())
def _checkpointFile(self, name, input_deserializer):
jrdd = self._jsc.checkpointFile(name)
pyspark/serializers.py
@@ -261,13 +261,13 @@ class MarshalSerializer(FramedSerializer):
loads = marshal.loads
- class MUTF8Deserializer(Serializer):
+ class UTF8Deserializer(Serializer):
"""
- Deserializes streams written by Java's DataOutputStream.writeUTF().
+ Deserializes streams written by getBytes.
"""
def loads(self, stream):
- length = struct.unpack('>H', stream.read(2))[0]
+ length = read_int(stream)
return stream.read(length).decode('utf8')
def load_stream(self, stream):
pyspark/worker.py
@@ -30,11 +30,11 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
- write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer
+ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
pickleSer = PickleSerializer()
- mutf8_deserializer = MUTF8Deserializer()
+ utf8_deserializer = UTF8Deserializer()
def report_times(outfile, boot, init, finish):
@@ -51,7 +51,7 @@ def main(infile, outfile):
return
# fetch name of workdir
- spark_files_dir = mutf8_deserializer.loads(infile)
+ spark_files_dir = utf8_deserializer.loads(infile)
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
@@ -66,7 +66,7 @@ def main(infile, outfile):
sys.path.append(spark_files_dir) # *.py files that were added will be copied here
num_python_includes = read_int(infile)
for _ in range(num_python_includes):
- filename = mutf8_deserializer.loads(infile)
+ filename = utf8_deserializer.loads(infile)
sys.path.append(os.path.join(spark_files_dir, filename))
command = pickleSer._read_with_length(infile)