From 2ccf3b665280bf5b0919e3801d028126cb070dbd Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@eecs.berkeley.edu>
Date: Sun, 28 Oct 2012 22:30:28 -0700
Subject: [PATCH] Fix PySpark hash partitioning bug.

A Java array's hashCode is based on its object
identity, not its elements, so this was causing
serialized keys to be hashed incorrectly.

This commit adds a PySpark-specific workaround
and more tests.
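
As an illustration only (not part of the patch
itself), a minimal Scala sketch of the issue,
runnable in a Scala REPL; the names a and b are
purely illustrative:

    val a = Array[Byte](1, 2, 3)
    val b = Array[Byte](1, 2, 3)
    // Identity-based hashing: usually false for
    // distinct arrays, even with equal contents.
    println(a.hashCode == b.hashCode)
    // Content-based hashing: true, which is what
    // PythonPartitioner relies on via Arrays.hashCode.
    println(java.util.Arrays.hashCode(a) ==
      java.util.Arrays.hashCode(b))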
---
 .../spark/api/python/PythonPartitioner.scala  | 33 +++++++++++++++
 .../scala/spark/api/python/PythonRDD.scala    | 10 ++---
 pyspark/pyspark/rdd.py                        | 12 ++++--
 3 files changed, 46 insertions(+), 9 deletions(-)
 create mode 100644 core/src/main/scala/spark/api/python/PythonPartitioner.scala

diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
new file mode 100644
index 0000000000..ef9f808fb2
--- /dev/null
+++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
@@ -0,0 +1,33 @@
+package spark.api.python
+
+import spark.Partitioner
+
+import java.util.Arrays
+
+/**
+ * A [[spark.Partitioner]] that hashes byte-array keys by their contents, for use by the Python API.
+ */
+class PythonPartitioner(override val numPartitions: Int) extends Partitioner {
+
+  override def getPartition(key: Any): Int = {
+    if (key == null) {
+      return 0
+    } else {
+      // Arrays hash by identity, so hash byte-array keys by their contents.
+      val hashCode = key match {
+        case bytes: Array[Byte] => Arrays.hashCode(bytes)
+        case _ => key.hashCode()
+      }
+      // Guard against negative hash codes
+      val mod = hashCode % numPartitions
+      if (mod < 0) mod + numPartitions else mod
+    }
+  }
+
+  override def equals(other: Any): Boolean = other match {
+    case h: PythonPartitioner =>
+      h.numPartitions == numPartitions
+    case _ =>
+      false
+  }
+}
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index a593e53efd..50094d6b0f 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -179,14 +179,12 @@ object PythonRDD {
     val dOut = new DataOutputStream(baos);
     if (elem.isInstanceOf[Array[Byte]]) {
       elem.asInstanceOf[Array[Byte]]
-    } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) {
-      val t = elem.asInstanceOf[scala.Tuple2[_, _]]
-      val t1 = t._1.asInstanceOf[Array[Byte]]
-      val t2 = t._2.asInstanceOf[Array[Byte]]
+    } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
+      val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
       dOut.writeByte(Pickle.PROTO)
       dOut.writeByte(Pickle.TWO)
-      dOut.write(PythonRDD.stripPickle(t1))
-      dOut.write(PythonRDD.stripPickle(t2))
+      dOut.write(PythonRDD.stripPickle(t._1))
+      dOut.write(PythonRDD.stripPickle(t._2))
       dOut.writeByte(Pickle.TUPLE2)
       dOut.writeByte(Pickle.STOP)
       baos.toByteArray()
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index e4878c08ba..85a24c6854 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -310,6 +310,12 @@ class RDD(object):
         return python_right_outer_join(self, other, numSplits)
 
     def partitionBy(self, numSplits, hashFunc=hash):
+        """
+        >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
+        >>> sets = pairs.partitionBy(2).glom().collect()
+        >>> set(sets[0]).intersection(set(sets[1]))
+        set([])
+        """
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
         def add_shuffle_key(iterator):
@@ -319,7 +325,7 @@ class RDD(object):
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
+        partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
         jrdd = pairRDD.partitionBy(partitioner)
         jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
         return RDD(jrdd, self.ctx)
@@ -391,7 +397,7 @@ class RDD(object):
         """
         >>> x = sc.parallelize([("a", 1), ("b", 4)])
         >>> y = sc.parallelize([("a", 2)])
-        >>> x.cogroup(y).collect()
+        >>> sorted(x.cogroup(y).collect())
         [('a', ([1], [2])), ('b', ([4], []))]
         """
         return python_cogroup(self, other, numSplits)
@@ -462,7 +468,7 @@ def _test():
     import doctest
     from pyspark.context import SparkContext
     globs = globals().copy()
-    globs['sc'] = SparkContext('local', 'PythonTest')
+    globs['sc'] = SparkContext('local[4]', 'PythonTest')
     doctest.testmod(globs=globs)
     globs['sc'].stop()
 
-- 
GitLab