From b92d16b114fd49e881d09e7974ad57b2a0df2906 Mon Sep 17 00:00:00 2001
From: Andrew Ash <andrew@andrewash.com>
Date: Tue, 17 Jun 2014 11:47:48 -0700
Subject: [PATCH] SPARK-1063 Add .sortBy(f) method on RDD

This never got merged from the apache/incubator-spark repo (which is now
deleted) but there had been several rounds of code review on this PR there.

I think this is ready for merging.

Author: Andrew Ash <andrew@andrewash.com>

This patch had conflicts when merged, resolved by
Committer: Reynold Xin <rxin@apache.org>

Closes #369 from ash211/sortby and squashes the following commits:

d09147a [Andrew Ash] Fix Ordering import
43d0a53 [Andrew Ash] Fix missing .collect()
29a54ed [Andrew Ash] Re-enable test by converting to a closure
5a95348 [Andrew Ash] Add license for RDDSuiteUtils
64ed6e3 [Andrew Ash] Remove leaked diff
d4de69a [Andrew Ash] Remove scar tissue
63638b5 [Andrew Ash] Add Python version of .sortBy()
45e0fde [Andrew Ash] Add Java version of .sortBy()
adf84c5 [Andrew Ash] Re-indent to keep line lengths under 100 chars
9d9b9d8 [Andrew Ash] Use parentheses on .collect() calls
0457b69 [Andrew Ash] Ignore failing test
99f0baf [Andrew Ash] Merge branch 'master' into sortby
222ae97 [Andrew Ash] Try moving Ordering objects out to a different class
3fd0dd3 [Andrew Ash] Add (failing) test for sortByKey with explicit Ordering
b8b5bbc [Andrew Ash] Align remove extra spaces that were used to align ='s in test code
8c53298 [Andrew Ash] Actually use ascending and numPartitions parameters
381eef2 [Andrew Ash] Correct silly typo
7db3e84 [Andrew Ash] Support ascending and numPartitions params in sortBy()
0f685fd [Andrew Ash] Merge remote-tracking branch 'origin/master' into sortby
ca4490d [Andrew Ash] Add .sortBy(f) method on RDD
---
 .../org/apache/spark/api/java/JavaRDD.scala   | 16 +++++
 .../main/scala/org/apache/spark/rdd/RDD.scala | 12 ++++
 .../java/org/apache/spark/JavaAPISuite.java   | 33 +++++++++++
 .../scala/org/apache/spark/rdd/RDDSuite.scala | 59 +++++++++++++++++--
 .../org/apache/spark/rdd/RDDSuiteUtils.scala  | 31 ++++++++++
 python/pyspark/rdd.py                         | 12 ++++
 6 files changed, 159 insertions(+), 4 deletions(-)
 create mode 100644 core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala

diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index 23d1371079..86fb374bef 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.api.java
 
+import java.util.Comparator
+
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
 
 import org.apache.spark._
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.api.java.function.{Function => JFunction}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
@@ -172,6 +175,19 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
     rdd.setName(name)
     this
   }
+
+  /**
+   * Return this RDD sorted by the given key function.
+   */
+  def sortBy[S](f: JFunction[T, S], ascending: Boolean, numPartitions: Int): JavaRDD[T] = {
+    import scala.collection.JavaConverters._
+    def fn = (x: T) => f.call(x)
+    import com.google.common.collect.Ordering  // shadows scala.math.Ordering
+    implicit val ordering = Ordering.natural().asInstanceOf[Ordering[S]]
+    implicit val ctag: ClassTag[S] = fakeClassTag
+    wrapRDD(rdd.sortBy(fn, ascending, numPartitions))
+  }
+
 }
 
 object JavaRDD {
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 27cc60d775..cf915b870e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -442,6 +442,18 @@ abstract class RDD[T: ClassTag](
    */
   def ++(other: RDD[T]): RDD[T] = this.union(other)
 
+  /**
+   * Return this RDD sorted by the given key function.
+   */
+  def sortBy[K](
+      f: (T) ⇒ K,
+      ascending: Boolean = true,
+      numPartitions: Int = this.partitions.size)
+      (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] =
+    this.keyBy[K](f)
+        .sortByKey(ascending, numPartitions)
+        .values
+
   /**
    * Return the intersection of this RDD and another one. The output will not contain any duplicate
    * elements, even if the input RDDs did.
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index ef41bfb88d..e46298c6a9 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -180,6 +180,39 @@ public class JavaAPISuite implements Serializable {
     Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2));
   }
 
+  @Test
+  public void sortBy() {
+    List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();
+    pairs.add(new Tuple2<Integer, Integer>(0, 4));
+    pairs.add(new Tuple2<Integer, Integer>(3, 2));
+    pairs.add(new Tuple2<Integer, Integer>(-1, 1));
+
+    JavaRDD<Tuple2<Integer, Integer>> rdd = sc.parallelize(pairs);
+
+    // compare on first value
+    JavaRDD<Tuple2<Integer, Integer>> sortedRDD = rdd.sortBy(new Function<Tuple2<Integer, Integer>, Integer>() {
+      public Integer call(Tuple2<Integer, Integer> t) throws Exception {
+        return t._1();
+      }
+    }, true, 2);
+
+    Assert.assertEquals(new Tuple2<Integer, Integer>(-1, 1), sortedRDD.first());
+    List<Tuple2<Integer, Integer>> sortedPairs = sortedRDD.collect();
+    Assert.assertEquals(new Tuple2<Integer, Integer>(0, 4), sortedPairs.get(1));
+    Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2));
+
+    // compare on second value
+    sortedRDD = rdd.sortBy(new Function<Tuple2<Integer, Integer>, Integer>() {
+      public Integer call(Tuple2<Integer, Integer> t) throws Exception {
+        return t._2();
+      }
+    }, true, 2);
+    Assert.assertEquals(new Tuple2<Integer, Integer>(-1, 1), sortedRDD.first());
+    sortedPairs = sortedRDD.collect();
+    Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(1));
+    Assert.assertEquals(new Tuple2<Integer, Integer>(0, 4), sortedPairs.get(2));
+  }
+
   @Test
   public void foreach() {
     final Accumulator<Integer> accum = sc.accumulator(0);
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index e94a1e76d4..0e5625b764 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -26,6 +26,8 @@ import org.apache.spark._
 import org.apache.spark.SparkContext._
 import org.apache.spark.util.Utils
 
+import org.apache.spark.rdd.RDDSuiteUtils._
+
 class RDDSuite extends FunSuite with SharedSparkContext {
 
   test("basic operations") {
@@ -585,14 +587,63 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     }
   }
 
+  test("sortByKey") {
+    val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B"))
+
+    val col1 = Array("4|60|C", "5|50|A", "6|40|B")
+    val col2 = Array("6|40|B", "5|50|A", "4|60|C")
+    val col3 = Array("5|50|A", "6|40|B", "4|60|C")
+
+    assert(data.sortBy(_.split("\\|")(0)).collect() === col1)
+    assert(data.sortBy(_.split("\\|")(1)).collect() === col2)
+    assert(data.sortBy(_.split("\\|")(2)).collect() === col3)
+  }
+
+  test("sortByKey ascending parameter") {
+    val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B"))
+
+    val asc = Array("4|60|C", "5|50|A", "6|40|B")
+    val desc = Array("6|40|B", "5|50|A", "4|60|C")
+
+    assert(data.sortBy(_.split("\\|")(0), true).collect() === asc)
+    assert(data.sortBy(_.split("\\|")(0), false).collect() === desc)
+  }
+
+  test("sortByKey with explicit ordering") {
+    val data = sc.parallelize(Seq("Bob|Smith|50",
+                                  "Jane|Smith|40",
+                                  "Thomas|Williams|30",
+                                  "Karen|Williams|60"))
+
+    val ageOrdered = Array("Thomas|Williams|30",
+                           "Jane|Smith|40",
+                           "Bob|Smith|50",
+                           "Karen|Williams|60")
+
+    // last name, then first name
+    val nameOrdered = Array("Bob|Smith|50",
+                            "Jane|Smith|40",
+                            "Karen|Williams|60",
+                            "Thomas|Williams|30")
+
+    val parse = (s: String) => {
+      val split = s.split("\\|")
+      Person(split(0), split(1), split(2).toInt)
+    }
+
+    import scala.reflect.classTag
+    assert(data.sortBy(parse, true, 2)(AgeOrdering, classTag[Person]).collect() === ageOrdered)
+    assert(data.sortBy(parse, true, 2)(NameOrdering, classTag[Person]).collect() === nameOrdered)
+  }
+
   test("intersection") {
     val all = sc.parallelize(1 to 10)
     val evens = sc.parallelize(2 to 10 by 2)
     val intersection = Array(2, 4, 6, 8, 10)
 
     // intersection is commutative
-    assert(all.intersection(evens).collect.sorted === intersection)
-    assert(evens.intersection(all).collect.sorted === intersection)
+    assert(all.intersection(evens).collect().sorted === intersection)
+    assert(evens.intersection(all).collect().sorted === intersection)
   }
 
   test("intersection strips duplicates in an input") {
@@ -600,8 +651,8 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     val b = sc.parallelize(Seq(1,1,2,3))
     val intersection = Array(1,2,3)
 
-    assert(a.intersection(b).collect.sorted === intersection)
-    assert(b.intersection(a).collect.sorted === intersection)
+    assert(a.intersection(b).collect().sorted === intersection)
+    assert(b.intersection(a).collect().sorted === intersection)
   }
 
   test("zipWithIndex") {
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala
new file mode 100644
index 0000000000..4762fc1785
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+object RDDSuiteUtils {
+  case class Person(first: String, last: String, age: Int)
+
+  object AgeOrdering extends Ordering[Person] {
+    def compare(a:Person, b:Person) = a.age compare b.age
+  }
+
+  object NameOrdering extends Ordering[Person] {
+    def compare(a:Person, b:Person) =
+      implicitly[Ordering[Tuple2[String,String]]].compare((a.last, a.first), (b.last, b.first))
+  }
+}
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index bb4d035edc..65f63153cd 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -549,6 +549,18 @@ class RDD(object):
                 .mapPartitions(mapFunc,preservesPartitioning=True)
                 .flatMap(lambda x: x, preservesPartitioning=True))
 
+    def sortBy(self, keyfunc, ascending=True, numPartitions=None):
+        """
+        Sorts this RDD by the given keyfunc
+
+        >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+        >>> sc.parallelize(tmp).sortBy(lambda x: x[0]).collect()
+        [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
+        >>> sc.parallelize(tmp).sortBy(lambda x: x[1]).collect()
+        [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+        """
+        return self.keyBy(keyfunc).sortByKey(ascending, numPartitions).values()
+
    def glom(self):
        """
        Return an RDD created by coalescing all elements within each partition
--
GitLab
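
Usage sketch (not part of the patch itself): the snippet below is a minimal illustration of how the new RDD.sortBy added above is expected to be called from Scala, reusing the pipe-delimited sample data from the RDDSuite tests. It assumes an already-constructed SparkContext named sc; the value names are illustrative only.

    // Illustrative only: sort an RDD of strings by a derived key.
    // Assumes an existing SparkContext `sc`; sample data mirrors RDDSuite.
    val records = sc.parallelize(Seq("5|50|A", "4|60|C", "6|40|B"))

    // Ascending sort on the first pipe-delimited field, into 2 partitions.
    val ascByFirst = records.sortBy(s => s.split("\\|")(0), ascending = true, numPartitions = 2)

    // Descending sort on the same field, keeping the default partition count.
    val descByFirst = records.sortBy(s => s.split("\\|")(0), ascending = false)

    println(ascByFirst.collect().mkString(", "))   // 4|60|C, 5|50|A, 6|40|B
    println(descByFirst.collect().mkString(", "))  // 6|40|B, 5|50|A, 4|60|C

Because the key type here is String, the implicit Ordering[String] and ClassTag[String] are resolved automatically; a custom Ordering can instead be supplied explicitly, as the "sortByKey with explicit ordering" test in the patch demonstrates.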