Commit 3a9d82cc authored by Andrew Ash, committed by Patrick Wendell

Merge pull request #506 from ash211/intersection. Closes #506.

SPARK-1062 Add rdd.intersection(otherRdd) method

Author: Andrew Ash <andrew@andrewash.com>

== Merge branch commits ==

commit 5d9982b171b9572649e9828f37ef0b43f0242912
Author: Andrew Ash <andrew@andrewash.com>
Date:   Thu Feb 6 18:11:45 2014 -0800

    Minor fixes

    - style: (v,null) => (v, null)
    - mention the shuffle in Javadoc

commit b86d02f14e810902719cef893cf6bfa18ff9acb0
Author: Andrew Ash <andrew@andrewash.com>
Date:   Sun Feb 2 13:17:40 2014 -0800

    Overload .intersection() for numPartitions and custom Partitioner

commit bcaa34911fcc6bb5bc5e4f9fe46d1df73cb71c09
Author: Andrew Ash <andrew@andrewash.com>
Date:   Sun Feb 2 13:05:40 2014 -0800

    Better naming of parameters in intersection's filter

commit b10a6af2d793ec6e9a06c798007fac3f6b860d89
Author: Andrew Ash <andrew@andrewash.com>
Date:   Sat Jan 25 23:06:26 2014 -0800

    Follow spark code format conventions of tab => 2 spaces

commit 965256e4304cca514bb36a1a36087711dec535ec
Author: Andrew Ash <andrew@andrewash.com>
Date:   Fri Jan 24 00:28:01 2014 -0800

    Add rdd.intersection(otherRdd) method
parent 1896c6e7
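
Before the diff, a minimal usage sketch of the API this merge adds. It is based only on the method signatures shown in the patch below; the local SparkContext setup, sample data, and partition counts are illustrative and not part of the patch.

import org.apache.spark.{HashPartitioner, SparkContext}

val sc = new SparkContext("local", "intersection-example")  // assumed local context for this sketch

val evens  = sc.parallelize(2 to 20 by 2)   // 2, 4, ..., 20
val threes = sc.parallelize(3 to 20 by 3)   // 3, 6, ..., 18

// Basic form: deduplicated intersection, computed with a shuffle.
evens.intersection(threes).collect.sorted                          // Array(6, 12, 18)

// Overloads added in this pull request: control the partitioning of the result.
evens.intersection(threes, 4).collect.sorted                       // hash-partitioned into 4 partitions
evens.intersection(threes, new HashPartitioner(2)).collect.sorted  // explicit Partitioner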
@@ -393,6 +393,43 @@ abstract class RDD[T: ClassTag](
    */
   def ++(other: RDD[T]): RDD[T] = this.union(other)
 
+  /**
+   * Return the intersection of this RDD and another one. The output will not contain any duplicate
+   * elements, even if the input RDDs did.
+   *
+   * Note that this method performs a shuffle internally.
+   */
+  def intersection(other: RDD[T]): RDD[T] =
+    this.map(v => (v, null)).cogroup(other.map(v => (v, null)))
+        .filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty }
+        .keys
+
+  /**
+   * Return the intersection of this RDD and another one. The output will not contain any duplicate
+   * elements, even if the input RDDs did.
+   *
+   * Note that this method performs a shuffle internally.
+   *
+   * @param partitioner Partitioner to use for the resulting RDD
+   */
+  def intersection(other: RDD[T], partitioner: Partitioner): RDD[T] =
+    this.map(v => (v, null)).cogroup(other.map(v => (v, null)), partitioner)
+        .filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty }
+        .keys
+
+  /**
+   * Return the intersection of this RDD and another one. The output will not contain any duplicate
+   * elements, even if the input RDDs did. Performs a hash partition across the cluster
+   *
+   * Note that this method performs a shuffle internally.
+   *
+   * @param numPartitions How many partitions to use in the resulting RDD
+   */
+  def intersection(other: RDD[T], numPartitions: Int): RDD[T] =
+    this.map(v => (v, null)).cogroup(other.map(v => (v, null)), new HashPartitioner(numPartitions))
+        .filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty }
+        .keys
+
   /**
    * Return an RDD created by coalescing all elements within each partition into an array.
    */
@@ -373,8 +373,8 @@ class RDDSuite extends FunSuite with SharedSparkContext {
       val prng42 = new Random(42)
       val prng43 = new Random(43)
       Array(1, 2, 3, 4, 5, 6).filter{i =>
-	if (i < 4) 0 == prng42.nextInt(3)
-	else 0 == prng43.nextInt(3)}
+        if (i < 4) 0 == prng42.nextInt(3)
+        else 0 == prng43.nextInt(3)}
     }
     assert(sample.size === checkSample.size)
     for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
@@ -506,4 +506,23 @@ class RDDSuite extends FunSuite with SharedSparkContext {
       sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false)
     }
   }
+
+  test("intersection") {
+    val all = sc.parallelize(1 to 10)
+    val evens = sc.parallelize(2 to 10 by 2)
+    val intersection = Array(2, 4, 6, 8, 10)
+
+    // intersection is commutative
+    assert(all.intersection(evens).collect.sorted === intersection)
+    assert(evens.intersection(all).collect.sorted === intersection)
+  }
+
+  test("intersection strips duplicates in an input") {
+    val a = sc.parallelize(Seq(1,2,3,3))
+    val b = sc.parallelize(Seq(1,1,2,3))
+    val intersection = Array(1,2,3)
+
+    assert(a.intersection(b).collect.sorted === intersection)
+    assert(b.intersection(a).collect.sorted === intersection)
+  }
 }
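
As an editorial aside, the new intersection() above boils down to cogrouping two self-keyed copies of the data and keeping only the keys that appear on both sides. A rough local-collections sketch of that idea follows; it has no Spark dependency, and localIntersection is an illustrative name, not part of the patch.

// Sketch of the cogroup-then-filter technique used by RDD.intersection above.
// Each element acts as its own key; a key is kept only when both sides contribute
// a non-empty group, which also strips duplicates from either input.
def localIntersection[T](left: Seq[T], right: Seq[T]): Seq[T] = {
  val leftGroups  = left.groupBy(identity)    // stands in for this.map(v => (v, null)) + grouping
  val rightGroups = right.groupBy(identity)   // stands in for other.map(v => (v, null)) + grouping
  val cogrouped = (leftGroups.keySet ++ rightGroups.keySet).toSeq.map { k =>
    (k, (leftGroups.getOrElse(k, Nil), rightGroups.getOrElse(k, Nil)))
  }
  cogrouped
    .filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty }
    .map { case (k, _) => k }
}

localIntersection(Seq(1, 2, 3, 3), Seq(1, 1, 2, 3)).sorted   // Seq(1, 2, 3), as in the tests above

Spark's version applies the same non-empty-groups filter after a distributed cogroup, which is why the Javadoc notes that the method performs a shuffle internally.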