Commit c593f632 authored by Matei Zaharia

Merge pull request #348 from JoshRosen/spark-597

Raise exception when hashing Java arrays (SPARK-597)
parents 3f74f729 f8039539
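
For context, here is a standalone sketch (not part of this commit) of the problem the patch guards against: Java arrays inherit identity-based hashCode and equals from Object, so two arrays with identical contents do not behave as equal keys, while wrapping them in a Seq restores content-based semantics.

    // Standalone Scala sketch, runnable as a script; nothing here is Spark API.
    val a = Array(1, 2, 3)
    val b = Array(1, 2, 3)
    println(a == b)                                // false: reference comparison
    println(a.hashCode == b.hashCode)              // almost always false: identity hash codes
    println(a.toSeq == b.toSeq)                    // true: Seq compares element contents
    println(a.toSeq.hashCode == b.toSeq.hashCode)  // true: content-based hash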
@@ -52,6 +52,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
       mergeCombiners: (C, C) => C,
       partitioner: Partitioner,
       mapSideCombine: Boolean = true): RDD[(K, C)] = {
+    if (getKeyClass().isArray) {
+      if (mapSideCombine) {
+        throw new SparkException("Cannot use map-side combining with array keys.")
+      }
+      if (partitioner.isInstanceOf[HashPartitioner]) {
+        throw new SparkException("Default partitioner cannot partition array keys.")
+      }
+    }
     val aggregator =
       new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
     if (mapSideCombine) {
@@ -92,6 +100,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
    * before sending results to a reducer, similarly to a "combiner" in MapReduce.
    */
   def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = {
+    if (getKeyClass().isArray) {
+      throw new SparkException("reduceByKeyLocally() does not support array keys")
+    }
     def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = {
       val map = new JHashMap[K, V]
       for ((k, v) <- iter) {
@@ -165,6 +178,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
    * be set to true.
    */
   def partitionBy(partitioner: Partitioner, mapSideCombine: Boolean = false): RDD[(K, V)] = {
+    if (getKeyClass().isArray) {
+      if (mapSideCombine) {
+        throw new SparkException("Cannot use map-side combining with array keys.")
+      }
+      if (partitioner.isInstanceOf[HashPartitioner]) {
+        throw new SparkException("Default partitioner cannot partition array keys.")
+      }
+    }
     if (mapSideCombine) {
       def createCombiner(v: V) = ArrayBuffer(v)
       def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
@@ -336,6 +357,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
    * list of values for that key in `this` as well as `other`.
    */
   def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = {
+    if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) {
+      throw new SparkException("Default partitioner cannot partition array keys.")
+    }
     val cg = new CoGroupedRDD[K](
       Seq(self.asInstanceOf[RDD[(_, _)]], other.asInstanceOf[RDD[(_, _)]]),
       partitioner)
@@ -352,6 +376,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
    */
   def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner)
       : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
+    if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) {
+      throw new SparkException("Default partitioner cannot partition array keys.")
+    }
     val cg = new CoGroupedRDD[K](
       Seq(self.asInstanceOf[RDD[(_, _)]],
           other1.asInstanceOf[RDD[(_, _)]],
@@ -11,6 +11,10 @@ abstract class Partitioner extends Serializable {
 /**
  * A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`.
+ *
+ * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
+ * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
+ * produce an unexpected or incorrect result.
  */
 class HashPartitioner(partitions: Int) extends Partitioner {
   def numPartitions = partitions
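To see why the new HashPartitioner doc comment warns against array keys, here is a minimal sketch of the usual "hashCode modulo numPartitions" scheme; partitionFor is a standalone helper written for illustration, not the class above. With identity-based array hash codes, two keys holding the same elements need not land in the same partition, so grouping and joining would silently miss matches.

    // Hypothetical helper mirroring hash-based partitioning; not Spark code.
    def partitionFor(key: Any, numPartitions: Int): Int = {
      val mod = key.hashCode % numPartitions
      if (mod < 0) mod + numPartitions else mod  // keep the bucket index non-negative
    }

    val k1 = Array(1, 2, 3)
    val k2 = Array(1, 2, 3)
    // Content-equal keys, yet their buckets need not agree, because the hash
    // codes are per-object identities:
    println((partitionFor(k1, 16), partitionFor(k2, 16)))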
@@ -417,6 +417,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
    * combine step happens locally on the master, equivalent to running a single reduce task.
    */
   def countByValue(): Map[T, Long] = {
+    if (elementClassManifest.erasure.isArray) {
+      throw new SparkException("countByValue() does not support arrays")
+    }
     // TODO: This should perhaps be distributed by default.
     def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = {
       val map = new OLMap[T]
@@ -445,6 +448,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
       timeout: Long,
       confidence: Double = 0.95
     ): PartialResult[Map[T, BoundedDouble]] = {
+    if (elementClassManifest.erasure.isArray) {
+      throw new SparkException("countByValueApprox() does not support arrays")
+    }
     val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) =>
       val map = new OLMap[T]
       while (iter.hasNext) {
@@ -107,4 +107,25 @@ class PartitioningSuite extends FunSuite with BeforeAndAfter {
     assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner)
     assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner)
   }
+
+  test("partitioning Java arrays should fail") {
+    sc = new SparkContext("local", "test")
+    val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x))
+    val arrPairs: RDD[(Array[Int], Int)] =
+      sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x))
+    assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array"))
+    // We can't catch all usages of arrays, since they might occur inside other collections:
+    //assert(fails { arrPairs.distinct() })
+    assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array"))
+    assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array"))
+    assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array"))
+    assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array"))
+    assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array"))
+    assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array"))
+    assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array"))
+    assert(intercept[SparkException]{ arrPairs.cogroup(arrPairs) }.getMessage.contains("array"))
+    assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array"))
+    assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array"))
+  }
 }
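
A common workaround, shown here only as a hedged sketch and not something introduced by this commit, is to convert array keys to a Seq (whose equals and hashCode are content-based) before any key-based operation, and convert back afterwards if array keys are needed. This assumes an RDD shaped like the arrPairs value in the test above, with the usual implicit conversions from spark.SparkContext._ in scope.

    // Sketch: make the keys hash by content instead of by identity.
    val summed: RDD[(Array[Int], Int)] = arrPairs
      .map { case (k, v) => (k.toSeq, v) }    // Seq keys compare and hash by content
      .reduceByKey(_ + _)                     // now safe with the default HashPartitioner
      .map { case (k, v) => (k.toArray, v) }  // restore array keys if required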