diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
index 57da85fa84f9914b5fa742278c1d23a98de9bd4f..deb2c24d0f16e93f73688d8ba2481c5fde533274 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
@@ -69,15 +69,18 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan}
  * post-shuffle partition. Once we have size statistics of pre-shuffle partitions from stages
  * corresponding to the registered [[ShuffleExchange]]s, we will do a pass of those statistics and
  * pack pre-shuffle partitions with continuous indices to a single post-shuffle partition until
- * the size of a post-shuffle partition is equal or greater than the target size.
+ * adding another pre-shuffle partition would cause the size of a post-shuffle partition to be
+ * greater than the target size.
+ *
  * For example, we have two stages with the following pre-shuffle partition size statistics:
  * stage 1: [100 MB, 20 MB, 100 MB, 10MB, 30 MB]
  * stage 2: [10 MB, 10 MB, 70 MB, 5 MB, 5 MB]
- * assuming the target input size is 128 MB, we will have three post-shuffle partitions,
+ * assuming the target input size is 128 MB, we will have four post-shuffle partitions,
  * which are:
- * - post-shuffle partition 0: pre-shuffle partition 0 and 1
- * - post-shuffle partition 1: pre-shuffle partition 2
- * - post-shuffle partition 2: pre-shuffle partition 3 and 4
+ * - post-shuffle partition 0: pre-shuffle partition 0 (size 110 MB)
+ * - post-shuffle partition 1: pre-shuffle partition 1 (size 30 MB)
+ * - post-shuffle partition 2: pre-shuffle partition 2 (size 170 MB)
+ * - post-shuffle partition 3: pre-shuffle partition 3 and 4 (size 50 MB)
  */
 class ExchangeCoordinator(
     numExchanges: Int,
@@ -164,25 +167,20 @@ class ExchangeCoordinator(
     while (i < numPreShufflePartitions) {
       // We calculate the total size of ith pre-shuffle partitions from all pre-shuffle stages.
       // Then, we add the total size to postShuffleInputSize.
+      var nextShuffleInputSize = 0L
       var j = 0
       while (j < mapOutputStatistics.length) {
-        postShuffleInputSize += mapOutputStatistics(j).bytesByPartitionId(i)
+        nextShuffleInputSize += mapOutputStatistics(j).bytesByPartitionId(i)
         j += 1
       }
 
-      // If the current postShuffleInputSize is equal or greater than the
-      // targetPostShuffleInputSize, We need to add a new element in partitionStartIndices.
-      if (postShuffleInputSize >= targetPostShuffleInputSize) {
-        if (i < numPreShufflePartitions - 1) {
-          // Next start index.
-          partitionStartIndices += i + 1
-        } else {
-          // This is the last element. So, we do not need to append the next start index to
-          // partitionStartIndices.
-        }
+      // If including the nextShuffleInputSize would exceed the target partition size, then start a
+      // new partition.
+      if (i > 0 && postShuffleInputSize + nextShuffleInputSize > targetPostShuffleInputSize) {
+        partitionStartIndices += i
         // reset postShuffleInputSize.
-        postShuffleInputSize = 0L
-      }
+        postShuffleInputSize = nextShuffleInputSize
+      } else postShuffleInputSize += nextShuffleInputSize
 
       i += 1
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
index 2803b624624176af86bf0814280e61fc3a7c0d16..06bce9a2400e741d121f6d3b0713c9f3f9553c17 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
@@ -85,7 +85,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
     {
       // There are a few large pre-shuffle partitions.
       val bytesByPartitionId = Array[Long](110, 10, 100, 110, 0)
-      val expectedPartitionStartIndices = Array[Int](0, 1, 3, 4)
+      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
       checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices)
     }
 
@@ -146,7 +146,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       // 2 post-shuffle partition are needed.
      val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0)
       val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
-      val expectedPartitionStartIndices = Array[Int](0, 3)
+      val expectedPartitionStartIndices = Array[Int](0, 2, 4)
       checkEstimation(
         coordinator,
         Array(bytesByPartitionId1, bytesByPartitionId2),
@@ -154,10 +154,10 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
     }
 
     {
-      // 2 post-shuffle partition are needed.
+      // 4 post-shuffle partition are needed.
       val bytesByPartitionId1 = Array[Long](0, 99, 0, 20, 0)
       val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
-      val expectedPartitionStartIndices = Array[Int](0, 2)
+      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4)
       checkEstimation(
         coordinator,
         Array(bytesByPartitionId1, bytesByPartitionId2),
@@ -168,7 +168,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       // 2 post-shuffle partition are needed.
       val bytesByPartitionId1 = Array[Long](0, 100, 0, 30, 0)
       val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
-      val expectedPartitionStartIndices = Array[Int](0, 2, 4)
+      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4)
       checkEstimation(
         coordinator,
         Array(bytesByPartitionId1, bytesByPartitionId2),
@@ -179,7 +179,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       // There are a few large pre-shuffle partitions.
       val bytesByPartitionId1 = Array[Long](0, 100, 40, 30, 0)
       val bytesByPartitionId2 = Array[Long](30, 0, 60, 0, 110)
-      val expectedPartitionStartIndices = Array[Int](0, 2, 3)
+      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
       checkEstimation(
         coordinator,
         Array(bytesByPartitionId1, bytesByPartitionId2),
@@ -228,7 +228,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       // The number of post-shuffle partitions is determined by the coordinator.
       val bytesByPartitionId1 = Array[Long](10, 50, 20, 80, 20)
       val bytesByPartitionId2 = Array[Long](40, 10, 0, 10, 30)
-      val expectedPartitionStartIndices = Array[Int](0, 2, 4)
+      val expectedPartitionStartIndices = Array[Int](0, 1, 3, 4)
       checkEstimation(
         coordinator,
         Array(bytesByPartitionId1, bytesByPartitionId2),
@@ -272,13 +272,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
         sparkConf.set(SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS.key, "-1")
     }
 
-    val spark = SparkSession.builder
+    val spark = SparkSession.builder()
       .config(sparkConf)
       .getOrCreate()
     try f(spark) finally spark.stop()
   }
 
-  Seq(Some(3), None).foreach { minNumPostShufflePartitions =>
+  Seq(Some(5), None).foreach { minNumPostShufflePartitions =>
     val testNameNote = minNumPostShufflePartitions match {
       case Some(numPartitions) => "(minNumPostShufflePartitions: 3)"
       case None => ""
@@ -290,7 +290,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
           spark
             .range(0, 1000, 1, numInputPartitions)
             .selectExpr("id % 20 as key", "id as value")
-        val agg = df.groupBy("key").count
+        val agg = df.groupBy("key").count()
 
         // Check the answer first.
         checkAnswer(
@@ -308,7 +308,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
             exchanges.foreach {
               case e: ShuffleExchange =>
                 assert(e.coordinator.isDefined)
-                assert(e.outputPartitioning.numPartitions === 3)
+                assert(e.outputPartitioning.numPartitions === 5)
               case o =>
             }
 
@@ -316,7 +316,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
             exchanges.foreach {
               case e: ShuffleExchange =>
                 assert(e.coordinator.isDefined)
-                assert(e.outputPartitioning.numPartitions === 2)
+                assert(e.outputPartitioning.numPartitions === 3)
               case o =>
             }
         }
@@ -359,7 +359,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
             exchanges.foreach {
               case e: ShuffleExchange =>
                 assert(e.coordinator.isDefined)
-                assert(e.outputPartitioning.numPartitions === 3)
+                assert(e.outputPartitioning.numPartitions === 5)
               case o =>
             }
 
@@ -383,14 +383,14 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
             .range(0, 1000, 1, numInputPartitions)
             .selectExpr("id % 500 as key1", "id as value1")
             .groupBy("key1")
-            .count
+            .count()
             .toDF("key1", "cnt1")
         val df2 =
           spark
             .range(0, 1000, 1, numInputPartitions)
             .selectExpr("id % 500 as key2", "id as value2")
             .groupBy("key2")
-            .count
+            .count()
             .toDF("key2", "cnt2")
 
         val join = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("cnt2"))
@@ -415,13 +415,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
             exchanges.foreach {
               case e: ShuffleExchange =>
                 assert(e.coordinator.isDefined)
-                assert(e.outputPartitioning.numPartitions === 3)
+                assert(e.outputPartitioning.numPartitions === 5)
               case o =>
             }
 
           case None =>
             assert(exchanges.forall(_.coordinator.isDefined))
-            assert(exchanges.map(_.outputPartitioning.numPartitions).toSeq.toSet === Set(1, 2))
+            assert(exchanges.map(_.outputPartitioning.numPartitions).toSet === Set(2, 3))
         }
       }
 
@@ -435,7 +435,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
             .range(0, 1000, 1, numInputPartitions)
             .selectExpr("id % 500 as key1", "id as value1")
             .groupBy("key1")
-            .count
+            .count()
             .toDF("key1", "cnt1")
         val df2 =
           spark
@@ -467,13 +467,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
             exchanges.foreach {
               case e: ShuffleExchange =>
                 assert(e.coordinator.isDefined)
-                assert(e.outputPartitioning.numPartitions === 3)
+                assert(e.outputPartitioning.numPartitions === 5)
               case o =>
             }
 
           case None =>
             assert(exchanges.forall(_.coordinator.isDefined))
-            assert(exchanges.map(_.outputPartitioning.numPartitions).toSeq.toSet === Set(2, 3))
+            assert(exchanges.map(_.outputPartitioning.numPartitions).toSet === Set(5, 3))
         }
       }
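
For reference, the snippet below is a minimal, self-contained Scala sketch of the packing rule this patch introduces; it is not part of the patch itself. It mirrors the new loop in estimatePartitionStartIndices, but takes plain size arrays instead of MapOutputStatistics and ignores the minimum-partition handling, so the updated test expectations above can be replayed directly. The object and method names are illustrative only.

import scala.collection.mutable.ArrayBuffer

// Sketch of the revised packing rule: pre-shuffle partition i joins the current
// post-shuffle partition unless adding it would push the running size past the
// target, in which case index i starts a new post-shuffle partition.
object PackingSketch {
  def estimatePartitionStartIndices(
      bytesByPartitionId: Seq[Array[Long]],
      targetPostShuffleInputSize: Long): Array[Int] = {
    val numPreShufflePartitions = bytesByPartitionId.head.length
    val partitionStartIndices = ArrayBuffer[Int](0)  // index 0 always starts a partition
    var postShuffleInputSize = 0L
    var i = 0
    while (i < numPreShufflePartitions) {
      // Total size of the ith pre-shuffle partition across all registered exchanges.
      val nextShuffleInputSize = bytesByPartitionId.map(_(i)).sum
      if (i > 0 && postShuffleInputSize + nextShuffleInputSize > targetPostShuffleInputSize) {
        partitionStartIndices += i
        postShuffleInputSize = nextShuffleInputSize
      } else {
        postShuffleInputSize += nextShuffleInputSize
      }
      i += 1
    }
    partitionStartIndices.toArray
  }

  def main(args: Array[String]): Unit = {
    // One of the updated test cases: per-partition totals [30, 10, 70, 20, 30] with a
    // target of 100 pack as pre-shuffle partitions [0, 1], [2, 3], [4], i.e. start indices 0, 2, 4.
    val indices = estimatePartitionStartIndices(
      Seq(Array[Long](0, 10, 0, 20, 0), Array[Long](30, 0, 70, 0, 30)), 100L)
    assert(indices.sameElements(Array(0, 2, 4)))
    println(indices.mkString(", "))
  }
}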