Skip to content
Snippets Groups Projects
Commit fbc26925 authored by donnyzone's avatar donnyzone Committed by gatorsmile
Browse files

[SPARK-19471][SQL] AggregationIterator does not initialize the generated...

[SPARK-19471][SQL] AggregationIterator does not initialize the generated result projection before using it

## What changes were proposed in this pull request?

Recently, we have also encountered such NPE issues in our production environment as described in:
https://issues.apache.org/jira/browse/SPARK-19471

This issue can be reproduced by the following examples:
` val df = spark.createDataFrame(Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4))).toDF("x", "y")

//HashAggregate, SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key=false
df.groupBy("x").agg(rand(),sum("y")).show()

//ObjectHashAggregate, SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key=false
df.groupBy("x").agg(rand(),collect_list("y")).show()

//SortAggregate, SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key=false &&SQLConf.USE_OBJECT_HASH_AGG.key=false
df.groupBy("x").agg(rand(),collect_list("y")).show()`
`

This PR is based on PR-16820(https://github.com/apache/spark/pull/16820) with test cases for all aggregation paths. We want to push it forward.

> When AggregationIterator generates result projection, it does not call the initialize method of the Projection class. This will cause a runtime NullPointerException when the projection involves nondeterministic expressions.

## How was this patch tested?

unit test
verified in production environment

Author: donnyzone <wellfengzhu@gmail.com>

Closes #18920 from DonnyZone/Branch-spark-19471.
parent 0326b69c
No related branches found
No related tags found
No related merge requests found
Showing with 63 additions and 3 deletions
......@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
* is used to generate result.
*/
abstract class AggregationIterator(
partIndex: Int,
groupingExpressions: Seq[NamedExpression],
inputAttributes: Seq[Attribute],
aggregateExpressions: Seq[AggregateExpression],
......@@ -217,6 +218,7 @@ abstract class AggregationIterator(
val resultProjection =
UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateAttributes)
resultProjection.initialize(partIndex)
(currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => {
// Generate results for all expression-based aggregate functions.
......@@ -235,6 +237,7 @@ abstract class AggregationIterator(
val resultProjection = UnsafeProjection.create(
groupingAttributes ++ bufferAttributes,
groupingAttributes ++ bufferAttributes)
resultProjection.initialize(partIndex)
// TypedImperativeAggregate stores generic object in aggregation buffer, and requires
// calling serialization before shuffling. See [[TypedImperativeAggregate]] for more info.
......@@ -256,6 +259,7 @@ abstract class AggregationIterator(
} else {
// Grouping-only: we only output values based on grouping expressions.
val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
resultProjection.initialize(partIndex)
(currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => {
resultProjection(currentGroupingKey)
}
......
......@@ -96,7 +96,7 @@ case class HashAggregateExec(
val spillSize = longMetric("spillSize")
val avgHashProbe = longMetric("avgHashProbe")
child.execute().mapPartitions { iter =>
child.execute().mapPartitionsWithIndex { (partIndex, iter) =>
val hasInput = iter.hasNext
if (!hasInput && groupingExpressions.nonEmpty) {
......@@ -106,6 +106,7 @@ case class HashAggregateExec(
} else {
val aggregationIterator =
new TungstenAggregationIterator(
partIndex,
groupingExpressions,
aggregateExpressions,
aggregateAttributes,
......
......@@ -31,6 +31,7 @@ import org.apache.spark.unsafe.KVIterator
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
class ObjectAggregationIterator(
partIndex: Int,
outputAttributes: Seq[Attribute],
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
......@@ -43,6 +44,7 @@ class ObjectAggregationIterator(
fallbackCountThreshold: Int,
numOutputRows: SQLMetric)
extends AggregationIterator(
partIndex,
groupingExpressions,
originalInputAttributes,
aggregateExpressions,
......
......@@ -98,7 +98,7 @@ case class ObjectHashAggregateExec(
val numOutputRows = longMetric("numOutputRows")
val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold
child.execute().mapPartitionsInternal { iter =>
child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) =>
val hasInput = iter.hasNext
if (!hasInput && groupingExpressions.nonEmpty) {
// This is a grouped aggregate and the input kvIterator is empty,
......@@ -107,6 +107,7 @@ case class ObjectHashAggregateExec(
} else {
val aggregationIterator =
new ObjectAggregationIterator(
partIndex,
child.output,
groupingExpressions,
aggregateExpressions,
......
......@@ -74,7 +74,7 @@ case class SortAggregateExec(
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
val numOutputRows = longMetric("numOutputRows")
child.execute().mapPartitionsInternal { iter =>
child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) =>
// Because the constructor of an aggregation iterator will read at least the first row,
// we need to get the value of iter.hasNext first.
val hasInput = iter.hasNext
......@@ -84,6 +84,7 @@ case class SortAggregateExec(
Iterator[UnsafeRow]()
} else {
val outputIter = new SortBasedAggregationIterator(
partIndex,
groupingExpressions,
child.output,
iter,
......
......@@ -27,6 +27,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
* sorted by values of [[groupingExpressions]].
*/
class SortBasedAggregationIterator(
partIndex: Int,
groupingExpressions: Seq[NamedExpression],
valueAttributes: Seq[Attribute],
inputIterator: Iterator[InternalRow],
......@@ -37,6 +38,7 @@ class SortBasedAggregationIterator(
newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
numOutputRows: SQLMetric)
extends AggregationIterator(
partIndex,
groupingExpressions,
valueAttributes,
aggregateExpressions,
......
......@@ -60,6 +60,8 @@ import org.apache.spark.unsafe.KVIterator
* - Part 8: A utility function used to generate a result when there is no
* input and there is no grouping expression.
*
* @param partIndex
* index of the partition
* @param groupingExpressions
* expressions for grouping keys
* @param aggregateExpressions
......@@ -77,6 +79,7 @@ import org.apache.spark.unsafe.KVIterator
* the iterator containing input [[UnsafeRow]]s.
*/
class TungstenAggregationIterator(
partIndex: Int,
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
aggregateAttributes: Seq[Attribute],
......@@ -91,6 +94,7 @@ class TungstenAggregationIterator(
spillSize: SQLMetric,
avgHashProbe: SQLMetric)
extends AggregationIterator(
partIndex,
groupingExpressions,
originalInputAttributes,
aggregateExpressions,
......
......@@ -24,6 +24,8 @@ import scala.util.Random
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
......@@ -449,6 +451,49 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
}
private def assertNoExceptions(c: Column): Unit = {
for ((wholeStage, useObjectHashAgg) <-
Seq((true, true), (true, false), (false, true), (false, false))) {
withSQLConf(
(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString),
(SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) {
val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y")
// HashAggregate test case
val hashAggDF = df.groupBy("x").agg(c, sum("y"))
val hashAggPlan = hashAggDF.queryExecution.executedPlan
if (wholeStage) {
assert(hashAggPlan.find {
case WholeStageCodegenExec(_: HashAggregateExec) => true
case _ => false
}.isDefined)
} else {
assert(hashAggPlan.isInstanceOf[HashAggregateExec])
}
hashAggDF.collect()
// ObjectHashAggregate and SortAggregate test case
val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y"))
val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan
if (useObjectHashAgg) {
assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec])
} else {
assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec])
}
objHashAggOrSortAggDF.collect()
}
}
}
test("SPARK-19471: AggregationIterator does not initialize the generated result projection" +
" before using it") {
Seq(
monotonically_increasing_id(), spark_partition_id(),
rand(Random.nextLong()), randn(Random.nextLong())
).foreach(assertNoExceptions)
}
test("SPARK-21281 use string types by default if array and map have no argument") {
val ds = spark.range(1)
var expectedSchema = new StructType()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment