Commit 3f6a2bb3 authored by Josh Rosen, committed by Herman van Hovell

[SPARK-17515] CollectLimit.execute() should perform per-partition limits

## What changes were proposed in this pull request?

CollectLimit.execute() incorrectly omits per-partition limits, leading to performance regressions when this code path is hit. That should not happen in normal operation, but it can occur in some cases (see #15068 for one example).
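
For illustration only, here is a minimal RDD-level sketch (not part of this commit) of why the per-partition limit matters: taking `limit` rows from each partition before the single-partition shuffle bounds how much data each task emits, and a final `take(limit)` after the shuffle preserves the query result. The object and variable names (`LimitSketch`, `locallyLimited`, etc.) outside the actual diff are hypothetical.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical, self-contained sketch of the per-partition limit idea.
object LimitSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("limit-sketch").master("local[4]").getOrCreate()
    val rdd = spark.sparkContext.parallelize(1L to 1000L, numSlices = 10)
    val limit = 5

    // Per-partition limit: each of the 10 partitions emits at most `limit` rows,
    // so at most 10 * limit = 50 rows are shuffled instead of all 1000.
    val locallyLimited = rdd.mapPartitions(_.take(limit))

    // Collapse to a single partition (a stand-in for the SinglePartition shuffle)
    // and apply the limit once more to produce the final result.
    val result = locallyLimited
      .coalesce(numPartitions = 1, shuffle = true)
      .mapPartitions(_.take(limit))
      .collect()

    println(result.mkString(", "))  // some 5 of the input values
    spark.stop()
  }
}
```

Without the per-partition step, every partition's rows would be shuffled into the single output partition before the limit is applied, which is the performance regression described above.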

## How was this patch tested?

Regression test in SQLQuerySuite that asserts the number of records scanned from the input RDD.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #15070 from JoshRosen/SPARK-17515.
parent 46f5c201
@@ -39,9 +39,10 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
   override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
   private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
   protected override def doExecute(): RDD[InternalRow] = {
+    val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit))
     val shuffled = new ShuffledRowRDD(
       ShuffleExchange.prepareShuffleDependency(
-        child.execute(), child.output, SinglePartition, serializer))
+        locallyLimited, child.output, SinglePartition, serializer))
     shuffled.mapPartitionsInternal(_.take(limit))
   }
 }
@@ -2661,4 +2661,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         data.selectExpr("`part.col1`", "`col.1`"))
     }
   }
+
+  test("SPARK-17515: CollectLimit.execute() should perform per-partition limits") {
+    val numRecordsRead = spark.sparkContext.longAccumulator
+    spark.range(1, 100, 1, numPartitions = 10).map { x =>
+      numRecordsRead.add(1)
+      x
+    }.limit(1).queryExecution.toRdd.count()
+    assert(numRecordsRead.value === 10)
+  }
 }