Commit 2ccf3b66 authored by Josh Rosen

Fix PySpark hash partitioning bug.

A Java array's hashCode is based on its object
identity, not its elements, so this was causing
serialized keys to be hashed incorrectly.

This commit adds a PySpark-specific workaround
and adds more tests.
parent 7859879a
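
The bug described in the commit message can be seen in isolation with a small Scala sketch (not part of the commit): two byte arrays with equal contents report different default hashCodes because Java arrays inherit Object.hashCode, while java.util.Arrays.hashCode hashes the elements.

object ArrayHashDemo {
  def main(args: Array[String]): Unit = {
    val a = Array[Byte](1, 2, 3)
    val b = Array[Byte](1, 2, 3)
    // Identity-based: almost always false for two distinct array objects.
    println(a.hashCode == b.hashCode)
    // Content-based: always true for arrays with equal elements.
    println(java.util.Arrays.hashCode(a) == java.util.Arrays.hashCode(b))
  }
}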
package spark.api.python

import spark.Partitioner

import java.util.Arrays

/**
 * A [[spark.Partitioner]] that hashes byte-array keys by their contents (rather than by object
 * identity), for use by the Python API.
 */
class PythonPartitioner(override val numPartitions: Int) extends Partitioner {

  override def getPartition(key: Any): Int = {
    if (key == null) {
      return 0
    } else {
      val hashCode = {
        if (key.isInstanceOf[Array[Byte]]) {
          // Hash the array's contents; a Java array's own hashCode is identity-based.
          Arrays.hashCode(key.asInstanceOf[Array[Byte]])
        } else {
          key.hashCode()
        }
      }
      // Guard against negative hash codes.
      val mod = hashCode % numPartitions
      if (mod < 0) {
        mod + numPartitions
      } else {
        mod
      }
    }
  }

  override def equals(other: Any): Boolean = other match {
    case h: PythonPartitioner =>
      h.numPartitions == numPartitions
    case _ =>
      false
  }
}
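
A side note on the mod guard above: in Scala, as in Java, % keeps the sign of the dividend, so a negative hashCode would otherwise produce a negative partition index. The same logic as a standalone sketch (hypothetical helper name, not part of the commit):

def nonNegativeMod(hashCode: Int, numPartitions: Int): Int = {
  val mod = hashCode % numPartitions
  // e.g. nonNegativeMod(-7, 3) == 2, whereas -7 % 3 == -1
  if (mod < 0) mod + numPartitions else mod
}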
@@ -179,14 +179,12 @@ object PythonRDD {
     val dOut = new DataOutputStream(baos);
     if (elem.isInstanceOf[Array[Byte]]) {
       elem.asInstanceOf[Array[Byte]]
-    } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) {
-      val t = elem.asInstanceOf[scala.Tuple2[_, _]]
-      val t1 = t._1.asInstanceOf[Array[Byte]]
-      val t2 = t._2.asInstanceOf[Array[Byte]]
+    } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
+      val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
       dOut.writeByte(Pickle.PROTO)
       dOut.writeByte(Pickle.TWO)
-      dOut.write(PythonRDD.stripPickle(t1))
-      dOut.write(PythonRDD.stripPickle(t2))
+      dOut.write(PythonRDD.stripPickle(t._1))
+      dOut.write(PythonRDD.stripPickle(t._2))
       dOut.writeByte(Pickle.TUPLE2)
       dOut.writeByte(Pickle.STOP)
       baos.toByteArray()
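
The Tuple2 branch above builds one pickled pair out of two values that were already pickled individually: it writes a fresh protocol-2 header, embeds each element's opcodes with their own header and STOP stripped off, then closes with TUPLE2 and STOP. A self-contained sketch of that idea follows; the opcode values mirror the standard pickle protocol-2 opcodes, and the behaviour of stripPickle (dropping the two-byte header and the trailing STOP) is an assumption here, since the commit's Pickle object is not shown in this diff.

import java.io.{ByteArrayOutputStream, DataOutputStream}

object PicklePairSketch {
  // Standard pickle protocol-2 opcodes (assumed to match the commit's Pickle object).
  val PROTO = 0x80   // protocol marker
  val TWO = 0x02     // protocol version 2
  val TUPLE2 = 0x86  // build a 2-tuple from the top two stack items
  val STOP = 0x2e    // '.', end of pickle

  // Assumed behaviour of PythonRDD.stripPickle: drop the 2-byte PROTO header
  // and the trailing STOP so the element's opcodes can be embedded elsewhere.
  def stripPickle(pickled: Array[Byte]): Array[Byte] =
    pickled.slice(2, pickled.length - 1)

  // Compose two independently pickled values into a single pickled 2-tuple.
  def pickledPair(first: Array[Byte], second: Array[Byte]): Array[Byte] = {
    val baos = new ByteArrayOutputStream()
    val dOut = new DataOutputStream(baos)
    dOut.writeByte(PROTO)
    dOut.writeByte(TWO)
    dOut.write(stripPickle(first))
    dOut.write(stripPickle(second))
    dOut.writeByte(TUPLE2)
    dOut.writeByte(STOP)
    dOut.flush()
    baos.toByteArray()
  }
}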
@@ -310,6 +310,12 @@ class RDD(object):
         return python_right_outer_join(self, other, numSplits)
 
     def partitionBy(self, numSplits, hashFunc=hash):
+        """
+        >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
+        >>> sets = pairs.partitionBy(2).glom().collect()
+        >>> set(sets[0]).intersection(set(sets[1]))
+        set([])
+        """
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
         def add_shuffle_key(iterator):
@@ -319,7 +325,7 @@ class RDD(object):
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
+        partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
         jrdd = pairRDD.partitionBy(partitioner)
         jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
         return RDD(jrdd, self.ctx)
@@ -391,7 +397,7 @@ class RDD(object):
         """
         >>> x = sc.parallelize([("a", 1), ("b", 4)])
         >>> y = sc.parallelize([("a", 2)])
-        >>> x.cogroup(y).collect()
+        >>> sorted(x.cogroup(y).collect())
         [('a', ([1], [2])), ('b', ([4], []))]
         """
         return python_cogroup(self, other, numSplits)
@@ -462,7 +468,7 @@ def _test():
     import doctest
     from pyspark.context import SparkContext
     globs = globals().copy()
-    globs['sc'] = SparkContext('local', 'PythonTest')
+    globs['sc'] = SparkContext('local[4]', 'PythonTest')
     doctest.testmod(globs=globs)
     globs['sc'].stop()