diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index d3e206b3537f15596b114c605fd8b789874060d4..413c944a66d93a37aa2716afbf6b869803ce6d52 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -52,6 +52,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( mergeCombiners: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true): RDD[(K, C)] = { + if (getKeyClass().isArray) { + if (mapSideCombine) { + throw new SparkException("Cannot use map-side combining with array keys.") + } + if (partitioner.isInstanceOf[HashPartitioner]) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + } val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) if (mapSideCombine) { @@ -92,6 +100,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * before sending results to a reducer, similarly to a "combiner" in MapReduce. */ def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = { + + if (getKeyClass().isArray) { + throw new SparkException("reduceByKeyLocally() does not support array keys") + } + def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = { val map = new JHashMap[K, V] for ((k, v) <- iter) { @@ -165,6 +178,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * be set to true. */ def partitionBy(partitioner: Partitioner, mapSideCombine: Boolean = false): RDD[(K, V)] = { + if (getKeyClass().isArray) { + if (mapSideCombine) { + throw new SparkException("Cannot use map-side combining with array keys.") + } + if (partitioner.isInstanceOf[HashPartitioner]) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + } if (mapSideCombine) { def createCombiner(v: V) = ArrayBuffer(v) def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v @@ -336,6 +357,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * list of values for that key in `this` as well as `other`. */ def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = { + if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { + throw new SparkException("Default partitioner cannot partition array keys.") + } val cg = new CoGroupedRDD[K]( Seq(self.asInstanceOf[RDD[(_, _)]], other.asInstanceOf[RDD[(_, _)]]), partitioner) @@ -352,6 +376,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( */ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner) : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { + if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { + throw new SparkException("Default partitioner cannot partition array keys.") + } val cg = new CoGroupedRDD[K]( Seq(self.asInstanceOf[RDD[(_, _)]], other1.asInstanceOf[RDD[(_, _)]], diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala index b71021a08229a29f5650fc0d73b10bd72c035075..9d5b966e1e5a092b542a2621aa2e91c17a9b299a 100644 --- a/core/src/main/scala/spark/Partitioner.scala +++ b/core/src/main/scala/spark/Partitioner.scala @@ -11,6 +11,10 @@ abstract class Partitioner extends Serializable { /** * A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`. + * + * Java arrays have hashCodes that are based on the arrays' identities rather than their contents, + * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will + * produce an unexpected or incorrect result. */ class HashPartitioner(partitions: Int) extends Partitioner { def numPartitions = partitions diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index d15c6f739663c10842a35ca2874bee4e1c20ae12..7e38583391337baf00d1ac935a225414ff74d895 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -417,6 +417,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial * combine step happens locally on the master, equivalent to running a single reduce task. */ def countByValue(): Map[T, Long] = { + if (elementClassManifest.erasure.isArray) { + throw new SparkException("countByValue() does not support arrays") + } // TODO: This should perhaps be distributed by default. def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = { val map = new OLMap[T] @@ -445,6 +448,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial timeout: Long, confidence: Double = 0.95 ): PartialResult[Map[T, BoundedDouble]] = { + if (elementClassManifest.erasure.isArray) { + throw new SparkException("countByValueApprox() does not support arrays") + } val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) => val map = new OLMap[T] while (iter.hasNext) { diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index 3dadc7acec1771b034ca8ff6ae7b1d8657149866..f09b602a7b306385add2bb64c642ebd12dfa8346 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ -107,4 +107,25 @@ class PartitioningSuite extends FunSuite with BeforeAndAfter { assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner) assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner) } + + test("partitioning Java arrays should fail") { + sc = new SparkContext("local", "test") + val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x)) + val arrPairs: RDD[(Array[Int], Int)] = + sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x)) + + assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array")) + // We can't catch all usages of arrays, since they might occur inside other collections: + //assert(fails { arrPairs.distinct() }) + assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.cogroup(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array")) + } }