Skip to content
Snippets Groups Projects
Commit 1283c3d1 authored by Wenchen Fan's avatar Wenchen Fan Committed by Herman van Hovell
Browse files

[SPARK-20725][SQL] partial aggregate should behave correctly for sameResult

## What changes were proposed in this pull request?

For an aggregate function with `PartialMerge` or `Final` mode, the input is aggregate buffers instead of the actual children expressions. Since the actual children expressions won't affect the result, we should normalize their expr ids.

## How was this patch tested?

a new regression test

Author: Wenchen Fan <wenchen@databricks.com>

Closes #17964 from cloud-fan/tmp.
parent 3f98375d
No related branches found
No related tags found
No related merge requests found
......@@ -105,12 +105,22 @@ case class AggregateExpression(
}
// We compute the same thing regardless of our final result.
override lazy val canonicalized: Expression =
// Canonical form used by `sameResult`: cosmetic differences (expr ids) are
// erased so that semantically-equal aggregate expressions compare equal.
override lazy val canonicalized: Expression = {
  // In PartialMerge/Final mode the aggregate function consumes aggregate
  // buffers rather than its declared children, so the children's expr ids
  // cannot influence the result; normalize them all to a fixed id.
  val funcToCanonicalize = mode match {
    case PartialMerge | Final =>
      aggregateFunction.transform {
        case attr: AttributeReference => attr.withExprId(ExprId(0))
      }
    case Partial | Complete =>
      aggregateFunction
  }
  AggregateExpression(
    funcToCanonicalize.canonicalized.asInstanceOf[AggregateFunction],
    mode,
    isDistinct,
    ExprId(0))
}
override def children: Seq[Expression] = aggregateFunction :: Nil
override def dataType: DataType = aggregateFunction.dataType
......
......@@ -286,7 +286,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
def recursiveTransform(arg: Any): AnyRef = arg match {
case e: Expression => transformExpression(e)
case Some(e: Expression) => Some(transformExpression(e))
case Some(value) => Some(recursiveTransform(value))
case m: Map[_, _] => m
case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map(recursiveTransform)
......@@ -320,7 +320,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
productIterator.flatMap {
case e: Expression => e :: Nil
case Some(e: Expression) => e :: Nil
case s: Some[_] => seqToExpressions(s.toSeq)
case seq: Traversable[_] => seqToExpressions(seq)
case other => Nil
}.toSeq
......
......@@ -18,12 +18,14 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
/**
* Tests for the sameResult function for [[SparkPlan]]s.
*/
class SameResultSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("FileSourceScanExec: different orders of data filters and partition filters") {
withTempPath { path =>
......@@ -46,4 +48,14 @@ class SameResultSuite extends QueryTest with SharedSQLContext {
df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get
.asInstanceOf[FileSourceScanExec]
}
test("SPARK-20725: partial aggregate should behave correctly for sameResult") {
  // Building the same aggregation twice must yield physical plans that
  // `sameResult` recognizes as equal; the distinct variant additionally
  // exercises the PartialMerge aggregate mode.
  def assertPlansSame(build: => DataFrame): Unit = {
    val plan1 = build.queryExecution.executedPlan
    val plan2 = build.queryExecution.executedPlan
    assert(plan1.sameResult(plan2))
  }
  assertPlansSame(spark.range(10).agg(sum($"id")))
  assertPlansSame(spark.range(10).agg(sumDistinct($"id")))
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment