Commit 1283c3d1 authored by Wenchen Fan, committed by Herman van Hovell

[SPARK-20725][SQL] partial aggregate should behave correctly for sameResult

## What changes were proposed in this pull request?

For an aggregate function in `PartialMerge` or `Final` mode, the input consists of aggregate buffers rather than the actual children expressions. Since the actual children expressions do not affect the result, we should normalize their expression IDs.
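
To see why the expression IDs matter: two independently analyzed but otherwise identical queries produce attributes that differ only in their auto-generated `ExprId`s. Below is a minimal sketch (runnable in `spark-shell`; `AttributeReference`, `ExprId`, and `withExprId` are the real Catalyst classes) showing that zeroing the IDs makes such attributes structurally equal:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId}
import org.apache.spark.sql.types.LongType

// Two references to the "same" column receive distinct auto-generated ExprIds
// when they are created independently, e.g. by two identical queries.
val a = AttributeReference("id", LongType)()
val b = AttributeReference("id", LongType)()
assert(a != b) // equal in everything except exprId

// Normalizing both to ExprId(0), as the canonicalized form does,
// makes the two attributes structurally equal.
assert(a.withExprId(ExprId(0)) == b.withExprId(ExprId(0)))
```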

## How was this patch tested?

A new regression test.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #17964 from cloud-fan/tmp.
parent 3f98375d
```diff
@@ -105,12 +105,22 @@ case class AggregateExpression(
   }

   // We compute the same thing regardless of our final result.
-  override lazy val canonicalized: Expression =
+  override lazy val canonicalized: Expression = {
+    val normalizedAggFunc = mode match {
+      // For PartialMerge or Final mode, the input to the `aggregateFunction` is aggregate
+      // buffers, and the actual children of `aggregateFunction` are not used, so we
+      // normalize their expr ids here.
+      case PartialMerge | Final => aggregateFunction.transform {
+        case a: AttributeReference => a.withExprId(ExprId(0))
+      }
+      case Partial | Complete => aggregateFunction
+    }
+
     AggregateExpression(
-      aggregateFunction.canonicalized.asInstanceOf[AggregateFunction],
+      normalizedAggFunc.canonicalized.asInstanceOf[AggregateFunction],
       mode,
       isDistinct,
       ExprId(0))
+  }

   override def children: Seq[Expression] = aggregateFunction :: Nil

   override def dataType: DataType = aggregateFunction.dataType
```
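For context, plan comparisons like `sameResult` ultimately reduce to comparing `canonicalized` trees. The following is a simplified sketch of the shape of Catalyst's `Expression.semanticEquals` (not the verbatim source): two deterministic expressions are equivalent iff their canonicalized forms match, which is why the leftover buffer-attribute `ExprId`s above had to be normalized.

```scala
import org.apache.spark.sql.catalyst.expressions.Expression

// Simplified sketch: semantic equivalence compares canonicalized trees, so any
// remaining ExprId differences (as in pre-patch PartialMerge/Final aggregates)
// make two otherwise identical expressions compare as different.
def semanticEquals(e1: Expression, e2: Expression): Boolean =
  e1.deterministic && e2.deterministic && e1.canonicalized == e2.canonicalized
```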
```diff
@@ -286,7 +286,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
     def recursiveTransform(arg: Any): AnyRef = arg match {
       case e: Expression => transformExpression(e)
-      case Some(e: Expression) => Some(transformExpression(e))
+      case Some(value) => Some(recursiveTransform(value))
       case m: Map[_, _] => m
       case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map(recursiveTransform)
@@ -320,7 +320,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
     productIterator.flatMap {
       case e: Expression => e :: Nil
-      case Some(e: Expression) => e :: Nil
+      case s: Some[_] => seqToExpressions(s.toSeq)
       case seq: Traversable[_] => seqToExpressions(seq)
       case other => Nil
     }.toSeq
```
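The `QueryPlan` change above generalizes the `Option` handling: instead of matching only `Some(e: Expression)`, it recurses into any `Some(value)`, so expressions nested inside `Option`s of sequences (or other containers) are reached as well. A self-contained sketch using hypothetical toy types (`Expr` and `Ref` are illustrations, not the Catalyst API):

```scala
// Toy stand-ins for Catalyst's Expression hierarchy, for illustration only.
sealed trait Expr
case class Ref(id: Int) extends Expr

// Rewrites every Ref it reaches to Ref(0).
def transformExpr(e: Expr): Expr = e match { case Ref(_) => Ref(0) }

def recursiveTransform(arg: Any): Any = arg match {
  case e: Expr => transformExpr(e)
  // Recursing into the Option's payload (the new behavior) also reaches
  // expressions nested in Some(Seq(...)), which `case Some(e: Expr)` would miss.
  case Some(value) => Some(recursiveTransform(value))
  case seq: Traversable[_] => seq.map(recursiveTransform)
  case other => other
}

assert(recursiveTransform(Some(Seq(Ref(42)))) == Some(Seq(Ref(0))))
```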
```diff
@@ -18,12 +18,14 @@
 package org.apache.spark.sql.execution

 import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext

 /**
  * Tests for the sameResult function for [[SparkPlan]]s.
  */
 class SameResultSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._

   test("FileSourceScanExec: different orders of data filters and partition filters") {
     withTempPath { path =>
@@ -46,4 +48,14 @@ class SameResultSuite extends QueryTest with SharedSQLContext {
     df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get
       .asInstanceOf[FileSourceScanExec]
   }
+
+  test("SPARK-20725: partial aggregate should behave correctly for sameResult") {
+    val df1 = spark.range(10).agg(sum($"id"))
+    val df2 = spark.range(10).agg(sum($"id"))
+    assert(df1.queryExecution.executedPlan.sameResult(df2.queryExecution.executedPlan))
+
+    val df3 = spark.range(10).agg(sumDistinct($"id"))
+    val df4 = spark.range(10).agg(sumDistinct($"id"))
+    assert(df3.queryExecution.executedPlan.sameResult(df4.queryExecution.executedPlan))
+  }
 }
```