diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 97625b94a0e23feba44eb60eeed81cc07a4dca87..40d5066a93f4c4d942a104b670ecb6610eec5661 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1173,7 +1173,7 @@ test_that("group by, agg functions", { expect_equal(3, count(mean(gd))) expect_equal(3, count(max(gd))) - expect_equal(30, collect(max(gd))[1, 2]) + expect_equal(30, collect(max(gd))[2, 2]) expect_equal(1, collect(count(gd))[1, 2]) mockLines2 <- c("{\"name\":\"ID1\", \"value\": \"10\"}", diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a7bc288e3886131bb15f4695547cac8d8e8ddda0..90a6b5d9c0dda127e79ea2072eb48dbec347a01d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -403,10 +403,10 @@ class DataFrame(object): +---+-----+ |age| name| +---+-----+ - | 2|Alice| - | 2|Alice| | 5| Bob| | 5| Bob| + | 2|Alice| + | 2|Alice| +---+-----+ >>> data = data.repartition(7, "age") >>> data.show() @@ -552,7 +552,7 @@ class DataFrame(object): >>> df_as2 = df.alias("df_as2") >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner') >>> joined_df.select(col("df_as1.name"), col("df_as2.name"), col("df_as2.age")).collect() - [Row(name=u'Alice', name=u'Alice', age=2), Row(name=u'Bob', name=u'Bob', age=5)] + [Row(name=u'Bob', name=u'Bob', age=5), Row(name=u'Alice', name=u'Alice', age=2)] """ assert isinstance(alias, basestring), "alias should be a string" return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx) @@ -573,14 +573,14 @@ class DataFrame(object): One of `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. 
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() - [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] >>> df.join(df2, 'name', 'outer').select('name', 'height').collect() - [Row(name=u'Tom', height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + [Row(name=u'Tom', height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] >>> cond = [df.name == df3.name, df.age == df3.age] >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect() - [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)] + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] >>> df.join(df2, 'name').select(df.name, df2.height).collect() [Row(name=u'Bob', height=85)] @@ -880,9 +880,9 @@ class DataFrame(object): >>> df.groupBy().avg().collect() [Row(avg(age)=3.5)] - >>> df.groupBy('name').agg({'age': 'mean'}).collect() + >>> sorted(df.groupBy('name').agg({'age': 'mean'}).collect()) [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] - >>> df.groupBy(df.name).avg().collect() + >>> sorted(df.groupBy(df.name).avg().collect()) [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] >>> df.groupBy(['name', df.age]).count().collect() [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] @@ -901,11 +901,11 @@ class DataFrame(object): +-----+----+-----+ | name| age|count| +-----+----+-----+ - |Alice|null| 1| + |Alice| 2| 1| | Bob| 5| 1| | Bob|null| 1| | null|null| 2| - |Alice| 2| 1| + |Alice|null| 1| +-----+----+-----+ """ jgd = self._jdf.rollup(self._jcols(*cols)) @@ -923,12 +923,12 @@ class DataFrame(object): | name| age|count| +-----+----+-----+ | null| 2| 1| - |Alice|null| 1| + |Alice| 2| 1| | Bob| 5| 1| - | Bob|null| 1| | null| 5| 1| + | Bob|null| 1| | null|null| 2| - |Alice| 2| 1| + |Alice|null| 1| +-----+----+-----+ """ jgd = 
self._jdf.cube(self._jcols(*cols)) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 9ca303a974cd4aa934ba44c57c8c78c80fcd177f..ee734cb439287ae4f4a4c27b81ad41edb3088f24 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -74,11 +74,11 @@ class GroupedData(object): or a list of :class:`Column`. >>> gdf = df.groupBy(df.name) - >>> gdf.agg({"*": "count"}).collect() + >>> sorted(gdf.agg({"*": "count"}).collect()) [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)] >>> from pyspark.sql import functions as F - >>> gdf.agg(F.min(df.age)).collect() + >>> sorted(gdf.agg(F.min(df.age)).collect()) [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] """ assert exprs, "exprs should not be empty" @@ -96,7 +96,7 @@ class GroupedData(object): def count(self): """Counts the number of records for each group. - >>> df.groupBy(df.age).count().collect() + >>> sorted(df.groupBy(df.age).count().collect()) [Row(age=2, count=1), Row(age=5, count=1)] """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 1bfe0ecb1e20b30a18803f39e87f2663df0eb942..d6e10c412ca1c1e56b3f87a41a0bb14ef0626a40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder, Unevaluable} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -249,6 +249,11 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } + /** + * Returns an expression that will produce a valid partition ID(i.e. 
non-negative and less + * than numPartitions) based on hashing expressions. + */ + def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 058d147c7d65d6595649a906e1fcfa0c771a2bff..3770883af1e2f078478726cf718a4cdf595a02ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -143,7 +143,13 @@ case class Exchange( val rdd = child.execute() val part: Partitioner = newPartitioning match { case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) - case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions) + case HashPartitioning(_, n) => + new Partitioner { + override def numPartitions: Int = n + // For HashPartitioning, the partitioning key is already a valid partition ID, as we use + // `HashPartitioning.partitionIdExpression` to produce the partitioning key. + override def getPartition(key: Any): Int = key.asInstanceOf[Int] + } case RangePartitioning(sortingExpressions, numPartitions) => // Internally, RangePartitioner runs a job on the RDD that samples keys to compute // partition bounds. To get accurate samples, we need to copy the mutable keys. 
@@ -173,7 +179,9 @@ case class Exchange( position += 1 position } - case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)() + case h: HashPartitioning => + val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, child.output) + row => projection(row).getInt(0) case RangePartitioning(_, _) | SinglePartition => identity case _ => sys.error(s"Exchange not implemented for $newPartitioning") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index fff72872c13b15f2a52e614ead03c75bb4854057..fc77529b7db3278cc7ef5e2977697ac2083bc934 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources import java.util.{Date, UUID} -import scala.collection.JavaConverters._ - import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter} @@ -30,6 +28,7 @@ import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} @@ -322,9 +321,12 @@ private[sql] class DynamicPartitionWriterContainer( spec => spec.sortColumnNames.map(c => inputSchema.find(_.name == c).get) } - private def bucketIdExpression: Option[Expression] = for { - BucketSpec(numBuckets, _, _) <- bucketSpec - } yield Pmod(new Murmur3Hash(bucketColumns), 
Literal(numBuckets)) + private def bucketIdExpression: Option[Expression] = bucketSpec.map { spec => + // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can + // guarantee the data distribution is the same between shuffle and bucketed data source, which + // enables us to only shuffle one side when joining a bucketed table and a normal one. + HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression + } // Expressions that given a partition key build a string like: col1=val/col2=val/... private def partitionStringExpression: Seq[Expression] = { @@ -341,12 +343,8 @@ private[sql] class DynamicPartitionWriterContainer( } } - private def getBucketIdFromKey(key: InternalRow): Option[Int] = { - if (bucketSpec.isDefined) { - Some(key.getInt(partitionColumns.length)) - } else { - None - } + private def getBucketIdFromKey(key: InternalRow): Option[Int] = bucketSpec.map { _ => + key.getInt(partitionColumns.length) } /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 8e0b2dbca4a98e25120b240c792ab1a1fca85e5c..ac1607ba3521a6f6bfab8658a21fe98e72347fd2 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -237,8 +237,8 @@ public class JavaDataFrameSuite { DataFrame crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); Assert.assertEquals("a_b", columnNames[0]); - Assert.assertEquals("1", columnNames[1]); - Assert.assertEquals("2", columnNames[2]); + Assert.assertEquals("2", columnNames[1]); + Assert.assertEquals("1", columnNames[2]); Row[] rows = crosstab.collect(); Arrays.sort(rows, crosstabRowComparator); Integer count = 1; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 9f8db39e33d7ecf4483326257576ce145f2081d1..1a3df1b117b6837689c4fff6373b3c84d8025a7d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -187,7 +187,7 @@ public class JavaDatasetSuite implements Serializable { } }, Encoders.STRING()); - Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList())); Dataset<String> flatMapped = grouped.flatMapGroups( new FlatMapGroupsFunction<Integer, String, String>() { @@ -202,7 +202,7 @@ public class JavaDatasetSuite implements Serializable { }, Encoders.STRING()); - Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList()); + Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList())); Dataset<Tuple2<Integer, String>> reduced = grouped.reduce(new ReduceFunction<String>() { @Override @@ -212,8 +212,8 @@ public class JavaDatasetSuite implements Serializable { }); Assert.assertEquals( - Arrays.asList(tuple2(1, "a"), tuple2(3, "foobar")), - reduced.collectAsList()); + asSet(tuple2(1, "a"), tuple2(3, "foobar")), + toSet(reduced.collectAsList())); List<Integer> data2 = Arrays.asList(2, 6, 10); Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()); @@ -245,7 +245,7 @@ public class JavaDatasetSuite implements Serializable { }, Encoders.STRING()); - Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList()); + Assert.assertEquals(asSet("1a#2", "3foobar#6", "5#10"), toSet(cogrouped.collectAsList())); } @Test @@ -268,7 +268,7 @@ public class JavaDatasetSuite implements Serializable { }, Encoders.STRING()); - Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList())); } @Test @@ -290,9 
+290,7 @@ public class JavaDatasetSuite implements Serializable { List<String> data = Arrays.asList("abc", "abc", "xyz"); Dataset<String> ds = context.createDataset(data, Encoders.STRING()); - Assert.assertEquals( - Arrays.asList("abc", "xyz"), - sort(ds.distinct().collectAsList().toArray(new String[0]))); + Assert.assertEquals(asSet("abc", "xyz"), toSet(ds.distinct().collectAsList())); List<String> data2 = Arrays.asList("xyz", "foo", "foo"); Dataset<String> ds2 = context.createDataset(data2, Encoders.STRING()); @@ -302,16 +300,23 @@ public class JavaDatasetSuite implements Serializable { Dataset<String> unioned = ds.union(ds2); Assert.assertEquals( - Arrays.asList("abc", "abc", "foo", "foo", "xyz", "xyz"), - sort(unioned.collectAsList().toArray(new String[0]))); + Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo"), + unioned.collectAsList()); Dataset<String> subtracted = ds.subtract(ds2); Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList()); } - private <T extends Comparable<T>> List<T> sort(T[] data) { - Arrays.sort(data); - return Arrays.asList(data); + private <T> Set<T> toSet(List<T> records) { + Set<T> set = new HashSet<T>(); + for (T record : records) { + set.add(record); + } + return set; + } + + private <T> Set<T> asSet(T... records) { + return toSet(Arrays.asList(records)); } @Test diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 983dfbdedeefe73d8ea94bd0e85f66320b31f410..d6c140dfea9ed0dfaa181a57f1b18d34d27431c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1083,17 +1083,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // Walk each partition and verify that it is sorted descending and does not contain all // the values. 
df4.rdd.foreachPartition { p => - var previousValue: Int = -1 - var allSequential: Boolean = true - p.foreach { r => - val v: Int = r.getInt(1) - if (previousValue != -1) { - if (previousValue < v) throw new SparkException("Partition is not ordered.") - if (v + 1 != previousValue) allSequential = false + // Skip empty partition + if (p.hasNext) { + var previousValue: Int = -1 + var allSequential: Boolean = true + p.foreach { r => + val v: Int = r.getInt(1) + if (previousValue != -1) { + if (previousValue < v) throw new SparkException("Partition is not ordered.") + if (v + 1 != previousValue) allSequential = false + } + previousValue = v } - previousValue = v + if (allSequential) throw new SparkException("Partition should not be globally ordered") } - if (allSequential) throw new SparkException("Partition should not be globally ordered") } // Distribute and order by with multiple order bys diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 693f5aea2d0155e18c8ffbb8873851b1902850e8..d7b86e381108ed7b7444860c589f2d9ddddda493 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -456,8 +456,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() - assert(ds.groupBy(p => p).count().collect().toSeq == - Seq((KryoData(1), 1L), (KryoData(2), 1L))) + assert(ds.groupBy(p => p).count().collect().toSet == + Set((KryoData(1), 1L), (KryoData(2), 1L))) } test("Kryo encoder self join") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5de0979606b88c001a355f09c6e52097b2462e84..03d67c4e91f7f8f983c14fea30902af74441def8 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -806,7 +806,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC") .limit(2) .registerTempTable("subset1") - sql("SELECT DISTINCT n FROM lowerCaseData") + sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n ASC") .limit(2) .registerTempTable("subset2") checkAnswer( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index b718b7cefb2a437c97fc66b519ade9e0ed380e7e..3ea9826544edb5c71ec376faa8c5d3daaea52a77 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.sources import java.io.File import org.apache.spark.sql.{AnalysisException, QueryTest} -import org.apache.spark.sql.catalyst.expressions.{Murmur3Hash, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.util.Utils class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ @@ -98,11 +98,12 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle val qe = readBack.select(bucketCols.map(col): _*).queryExecution val rows = qe.toRdd.map(_.copy()).collect() - val getHashCode = - UnsafeProjection.create(new Murmur3Hash(qe.analyzed.output) :: Nil, qe.analyzed.output) + val getHashCode = UnsafeProjection.create( + HashPartitioning(qe.analyzed.output, 8).partitionIdExpression :: 
Nil, + qe.analyzed.output) for (row <- rows) { - val actualBucketId = Utils.nonNegativeMod(getHashCode(row).getInt(0), 8) + val actualBucketId = getHashCode(row).getInt(0) assert(actualBucketId == bucketId) } }