Commit e2614038 authored by Reynold Xin

[SPARK-3408] Fixed Limit operator so it works with sort-based shuffle.

Author: Reynold Xin <rxin@apache.org>

Closes #2281 from rxin/sql-limit-sort and squashes the following commits:

1ef7780 [Reynold Xin] [SPARK-3408] Fixed Limit operator so it works with sort-based shuffle.
parent 39db1bfd
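Background for the change, with a minimal sketch (not Spark code) of the pitfall it guards against: the hash-based shuffle serializes each record as soon as it is produced, so the SQL exchange path can safely reuse one mutable pair per partition, while the sort-based shuffle buffers record references in memory before sorting them, so a reused row object must be copied first or every buffered entry ends up pointing at the last row. The MutableRow class below is invented for illustration and only stands in for Spark's mutable row reuse.

// Minimal illustration (hypothetical MutableRow, not Spark's Row): buffering a
// reused mutable object yields N references to one instance, so every element
// reflects the last mutation; copying before buffering preserves the values.
object MutableRowAliasing {
  final class MutableRow(var value: Int) {
    def copy(): MutableRow = new MutableRow(value)
  }

  def main(args: Array[String]): Unit = {
    val reused = new MutableRow(0)

    // What a buffering (sort-based) shuffle would see without copying.
    val buffered = (1 to 3).map { i => reused.value = i; reused }
    println(buffered.map(_.value).mkString(", "))   // 3, 3, 3  -- all aliases

    // Copying each row before it is buffered keeps the intended values.
    val copied = (1 to 3).map { i => reused.value = i; reused.copy() }
    println(copied.map(_.value).mkString(", "))     // 1, 2, 3
  }
}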
@@ -20,10 +20,10 @@ package org.apache.spark.sql.execution
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.runtime.universe.TypeTag

+import org.apache.spark.{SparkEnv, HashPartitioner, SparkConf}
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.{HashPartitioner, SparkConf}
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
@@ -96,6 +96,9 @@ case class Limit(limit: Int, child: SparkPlan)
   // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
   // partition local limit -> exchange into one partition -> partition local limit again

+  /** We must copy rows when sort based shuffle is on */
+  private def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
+
   override def output = child.output

   /**
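The new sortBasedShuffleOn check above decides at runtime which branch of Limit.execute() is taken. A hedged sketch of the same check from user code follows (Spark 1.1-era API assumed; the object name and app name are invented), selecting the sort-based implementation through the spark.shuffle.manager setting:

// Hedged sketch (names invented): enable the sort-based shuffle, then repeat
// the runtime check the Limit operator now performs before copying rows.
import org.apache.spark.{SparkConf, SparkContext, SparkEnv}
import org.apache.spark.shuffle.sort.SortShuffleManager

object SortShuffleCheck {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("sort-shuffle-check")
      .set("spark.shuffle.manager", "sort")
    val sc = new SparkContext(conf)

    val sortBasedShuffleOn =
      SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
    println(s"sort-based shuffle enabled: $sortBasedShuffleOn")

    sc.stop()
  }
}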
@@ -143,9 +146,15 @@ case class Limit(limit: Int, child: SparkPlan)
   }

   override def execute() = {
-    val rdd = child.execute().mapPartitions { iter =>
-      val mutablePair = new MutablePair[Boolean, Row]()
-      iter.take(limit).map(row => mutablePair.update(false, row))
+    val rdd: RDD[_ <: Product2[Boolean, Row]] = if (sortBasedShuffleOn) {
+      child.execute().mapPartitions { iter =>
+        iter.take(limit).map(row => (false, row.copy()))
+      }
+    } else {
+      child.execute().mapPartitions { iter =>
+        val mutablePair = new MutablePair[Boolean, Row]()
+        iter.take(limit).map(row => mutablePair.update(false, row))
+      }
     }
     val part = new HashPartitioner(1)
     val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part)
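For reference, a minimal sketch of the overall pattern that Limit.execute() implements follows: take limit rows in each partition, route every surviving row to a single reducer with HashPartitioner(1), then take limit once more on the reduce side. It is written against the plain RDD API with Int records rather than the SQL operator itself, and the object and app names are illustrative.

// Illustrative only: mirrors the structure of Limit.execute() (partition-local
// take, shuffle into one partition, take again) using ordinary Int records.
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.rdd.ShuffledRDD

object LimitPattern {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("limit-pattern"))
    val limit = 5

    // Partition-local limit; the constant Boolean key sends every row to the same reducer.
    val keyed = sc.parallelize(1 to 100, numSlices = 4).mapPartitions { iter =>
      iter.take(limit).map(row => (false, row))
    }

    val shuffled = new ShuffledRDD[Boolean, Int, Int](keyed, new HashPartitioner(1))
    val result = shuffled.mapPartitions(_.take(limit).map(_._2)).collect()
    println(result.mkString(", "))

    sc.stop()
  }
}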