From 3afe448d39dc4877b2f2c62b3059aeb3ced0bd96 Mon Sep 17 00:00:00 2001
From: Yin Huai <yhuai@databricks.com>
Date: Wed, 21 Oct 2015 13:43:17 -0700
Subject: [PATCH] [SPARK-9740][SPARK-9592][SPARK-9210][SQL] Change the default
 behavior of First/Last to RESPECT NULLS.

I am changing the default behavior of `First`/`Last` to respect null values
(the SQL standard default behavior).

https://issues.apache.org/jira/browse/SPARK-9740

Author: Yin Huai <yhuai@databricks.com>

Closes #8113 from yhuai/firstLast.
---
 .../sql/catalyst/analysis/CheckAnalysis.scala |   3 +-
 .../catalyst/analysis/FunctionRegistry.scala  |   2 +
 .../expressions/aggregate/functions.scala     | 105 +++++++++++++++---
 .../expressions/aggregate/utils.scala         |   8 +-
 .../sql/catalyst/expressions/aggregates.scala |  95 ++++++++++++----
 .../spark/sql/expressions/WindowSpec.scala    |  13 ++-
 .../execution/AggregationQuerySuite.scala     |  38 +++++++
 7 files changed, 219 insertions(+), 45 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index ab215407f7..98d6637c06 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -113,7 +113,8 @@ trait CheckAnalysis {
             failAnalysis(
               s"expression '${e.prettyString}' is neither present in the group by, " +
                 s"nor is it an aggregate function. " +
-                "Add to group by or wrap in first() if you don't care which value you get.")
+                "Add to group by or wrap in first() (or first_value) if you don't care " +
+                "which value you get.")
           case e if groupingExprs.exists(_.semanticEquals(e)) => // OK
           case e if e.references.isEmpty => // OK
           case e => e.children.foreach(checkValidAggregateExpression)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ba77b70a37..f73b24e363 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -179,7 +179,9 @@ object FunctionRegistry {
     expression[Average]("avg"),
     expression[Count]("count"),
     expression[First]("first"),
+    expression[First]("first_value"),
     expression[Last]("last"),
+    expression[Last]("last_value"),
     expression[Max]("max"),
     expression[Min]("min"),
     expression[Stddev]("stddev"),
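
With `first_value` and `last_value` registered as aliases of `First` and `Last`, both
names are usable from SQL, and both accept an optional second boolean argument that
requests IGNORE NULLS behavior. A minimal usage sketch, assuming a hypothetical table
`src` with columns `key` and `value` and a SQLContext named `sqlContext`:

    // Sketch only: `src`, `key`, and `value` are hypothetical. With the new default,
    // first(key) / first_value(key) respect nulls; passing `true` ignores them.
    sqlContext.sql(
      """
        |SELECT value,
        |       first_value(key)       AS first_respect_nulls,
        |       first_value(key, true) AS first_ignore_nulls,
        |       last(key, true)        AS last_ignore_nulls
        |FROM src
        |GROUP BY value
      """.stripMargin)
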
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index c0bc7ec09c..515246d344 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -21,6 +21,8 @@ import java.lang.{Long => JLong}
 import java.util
 
 import com.clearspring.analytics.hash.MurmurHash
+
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
@@ -118,7 +120,23 @@ case class Count(child: Expression) extends DeclarativeAggregate {
   override val evaluateExpression = Cast(currentCount, LongType)
 }
 
-case class First(child: Expression) extends DeclarativeAggregate {
+/**
+ * Returns the first value of `child` for a group of rows. If the first value of `child`
+ * is `null`, it returns `null` (respecting nulls). Even if [[First]] is used on an already
+ * sorted column, if we do partial aggregation and final aggregation (when mergeExpressions
+ * is used) its result will not be deterministic (unless the input table is sorted and has
+ * a single partition, and we use a single reducer to do the aggregation).
+ * @param child
+ */
+case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
+
+  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
+
+  private val ignoreNulls: Boolean = ignoreNullsExpr match {
+    case Literal(b: Boolean, BooleanType) => b
+    case _ =>
+      throw new AnalysisException("The second argument of First should be a boolean literal.")
+  }
 
   override def children: Seq[Expression] = child :: Nil
 
@@ -135,24 +153,61 @@ case class First(child: Expression) extends DeclarativeAggregate {
 
   private val first = AttributeReference("first", child.dataType)()
 
-  override val aggBufferAttributes = first :: Nil
+  private val valueSet = AttributeReference("valueSet", BooleanType)()
+
+  override val aggBufferAttributes = first :: valueSet :: Nil
 
   override val initialValues = Seq(
-    /* first = */ Literal.create(null, child.dataType)
+    /* first = */ Literal.create(null, child.dataType),
+    /* valueSet = */ Literal.create(false, BooleanType)
   )
 
-  override val updateExpressions = Seq(
-    /* first = */ If(IsNull(first), child, first)
-  )
+  override val updateExpressions = {
+    if (ignoreNulls) {
+      Seq(
+        /* first = */ If(Or(valueSet, IsNull(child)), first, child),
+        /* valueSet = */ Or(valueSet, IsNotNull(child))
+      )
+    } else {
+      Seq(
+        /* first = */ If(valueSet, first, child),
+        /* valueSet = */ Literal.create(true, BooleanType)
+      )
+    }
+  }
 
-  override val mergeExpressions = Seq(
-    /* first = */ If(IsNull(first.left), first.right, first.left)
-  )
+  override val mergeExpressions = {
+    // For first, we can just check if valueSet.left is set to true. If it is set
+    // to true, we use first.left. If not, we use first.right (even if valueSet.right is
+    // false, we are safe to do so because first.right will be null in this case).
+    Seq(
+      /* first = */ If(valueSet.left, first.left, first.right),
+      /* valueSet = */ Or(valueSet.left, valueSet.right)
+    )
+  }
 
   override val evaluateExpression = first
+
+  override def toString: String = s"FIRST($child)${if (ignoreNulls) " IGNORE NULLS" else ""}"
 }
 
-case class Last(child: Expression) extends DeclarativeAggregate {
+/**
+ * Returns the last value of `child` for a group of rows. If the last value of `child`
+ * is `null`, it returns `null` (respecting nulls). Even if [[Last]] is used on an already
+ * sorted column, if we do partial aggregation and final aggregation (when mergeExpressions
+ * is used) its result will not be deterministic (unless the input table is sorted and has
+ * a single partition, and we use a single reducer to do the aggregation).
+ * @param child
+ */
+case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
+
+  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
+
+  private val ignoreNulls: Boolean = ignoreNullsExpr match {
+    case Literal(b: Boolean, BooleanType) => b
+    case _ =>
+      throw new AnalysisException("The second argument of Last should be a boolean literal.")
+  }
 
   override def children: Seq[Expression] = child :: Nil
 
@@ -175,15 +230,33 @@
     /* last = */ Literal.create(null, child.dataType)
   )
 
-  override val updateExpressions = Seq(
-    /* last = */ If(IsNull(child), last, child)
-  )
+  override val updateExpressions = {
+    if (ignoreNulls) {
+      Seq(
+        /* last = */ If(IsNull(child), last, child)
+      )
+    } else {
+      Seq(
+        /* last = */ child
+      )
+    }
+  }
 
-  override val mergeExpressions = Seq(
-    /* last = */ If(IsNull(last.right), last.left, last.right)
-  )
+  override val mergeExpressions = {
+    if (ignoreNulls) {
+      Seq(
+        /* last = */ If(IsNull(last.right), last.left, last.right)
+      )
+    } else {
+      Seq(
+        /* last = */ last.right
+      )
+    }
+  }
 
   override val evaluateExpression = last
+
+  override def toString: String = s"LAST($child)${if (ignoreNulls) " IGNORE NULLS" else ""}"
 }
 
 case class Max(child: Expression) extends DeclarativeAggregate {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
index f656ccf13b..12bdab0915 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
@@ -61,15 +61,15 @@ object Utils {
           mode = aggregate.Complete,
           isDistinct = true)
 
-      case expressions.First(child) =>
+      case expressions.First(child, ignoreNulls) =>
         aggregate.AggregateExpression2(
-          aggregateFunction = aggregate.First(child),
+          aggregateFunction = aggregate.First(child, ignoreNulls),
           mode = aggregate.Complete,
           isDistinct = false)
 
-      case expressions.Last(child) =>
+      case expressions.Last(child, ignoreNulls) =>
         aggregate.AggregateExpression2(
-          aggregateFunction = aggregate.Last(child),
+          aggregateFunction = aggregate.Last(child, ignoreNulls),
           mode = aggregate.Complete,
           isDistinct = false)
 
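
The declarative `updateExpressions`/`mergeExpressions` of the new `First` above are terse;
the following rough imperative sketch (illustration only, using a hypothetical `FirstBuffer`
that does not appear in the patch) spells out the semantics they encode:

    object FirstSemanticsSketch {
      // A mutable mirror of First's (first, valueSet) aggregation buffer.
      case class FirstBuffer(var first: Any = null, var valueSet: Boolean = false)

      def update(buf: FirstBuffer, value: Any, ignoreNulls: Boolean): Unit = {
        if (ignoreNulls) {
          // Keep the existing value once set; otherwise take the first non-null input.
          if (!buf.valueSet && value != null) { buf.first = value; buf.valueSet = true }
        } else {
          // Respect nulls (the new default): the very first input wins, even if it is null.
          if (!buf.valueSet) { buf.first = value; buf.valueSet = true }
        }
      }

      def merge(left: FirstBuffer, right: FirstBuffer): FirstBuffer = {
        // If the left partial aggregate saw a qualifying row, its value wins;
        // otherwise the right side's buffer (possibly still unset) is used.
        if (left.valueSet) left else right
      }
    }
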
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index f1c47f3904..95061c4635 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import com.clearspring.analytics.stream.cardinality.HyperLogLog
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
@@ -630,59 +631,113 @@ case class CombineSetsAndSumFunction(
   }
 }
 
-case class First(child: Expression) extends UnaryExpression with PartialAggregate1 {
+case class First(
+    child: Expression,
+    ignoreNullsExpr: Expression)
+  extends UnaryExpression with PartialAggregate1 {
+
+  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
+
+  private val ignoreNulls: Boolean = ignoreNullsExpr match {
+    case Literal(b: Boolean, BooleanType) => b
+    case _ =>
+      throw new AnalysisException("The second argument of First should be a boolean literal.")
+  }
+
   override def nullable: Boolean = true
   override def dataType: DataType = child.dataType
-  override def toString: String = s"FIRST($child)"
+  override def toString: String = s"FIRST(${child}${if (ignoreNulls) " IGNORE NULLS" else ""})"
 
   override def asPartial: SplitEvaluation = {
-    val partialFirst = Alias(First(child), "PartialFirst")()
+    val partialFirst = Alias(First(child, ignoreNulls), "PartialFirst")()
     SplitEvaluation(
-      First(partialFirst.toAttribute),
+      First(partialFirst.toAttribute, ignoreNulls),
       partialFirst :: Nil)
   }
-  override def newInstance(): FirstFunction = new FirstFunction(child, this)
+  override def newInstance(): FirstFunction = new FirstFunction(child, ignoreNulls, this)
 }
 
-case class FirstFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
-  def this() = this(null, null) // Required for serialization.
+object First {
+  def apply(child: Expression): First = First(child, ignoreNulls = false)
 
-  var result: Any = null
+  def apply(child: Expression, ignoreNulls: Boolean): First =
+    First(child, Literal.create(ignoreNulls, BooleanType))
+}
+
+case class FirstFunction(
+    expr: Expression,
+    ignoreNulls: Boolean,
+    base: AggregateExpression1)
+  extends AggregateFunction1 {
+
+  def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization.
+
+  private[this] var result: Any = null
+
+  private[this] var valueSet: Boolean = false
 
   override def update(input: InternalRow): Unit = {
-    // We ignore null values.
-    if (result == null) {
-      result = expr.eval(input)
+    if (!valueSet) {
+      val value = expr.eval(input)
+      // When the result has not been set, set it if we respect nulls (i.e. ignoreNulls is
+      // false), or if we ignore nulls and the evaluated value is not null.
+      if (!ignoreNulls || (ignoreNulls && value != null)) {
+        result = value
+        valueSet = true
+      }
     }
   }
 
   override def eval(input: InternalRow): Any = result
 }
 
-case class Last(child: Expression) extends UnaryExpression with PartialAggregate1 {
+case class Last(
+    child: Expression,
+    ignoreNullsExpr: Expression)
+  extends UnaryExpression with PartialAggregate1 {
+
+  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
+
+  private val ignoreNulls: Boolean = ignoreNullsExpr match {
+    case Literal(b: Boolean, BooleanType) => b
+    case _ =>
+      throw new AnalysisException("The second argument of Last should be a boolean literal.")
+  }
+
   override def references: AttributeSet = child.references
   override def nullable: Boolean = true
   override def dataType: DataType = child.dataType
-  override def toString: String = s"LAST($child)"
+  override def toString: String = s"LAST($child)${if (ignoreNulls) " IGNORE NULLS" else ""}"
 
   override def asPartial: SplitEvaluation = {
-    val partialLast = Alias(Last(child), "PartialLast")()
+    val partialLast = Alias(Last(child, ignoreNulls), "PartialLast")()
     SplitEvaluation(
-      Last(partialLast.toAttribute),
+      Last(partialLast.toAttribute, ignoreNulls),
       partialLast :: Nil)
   }
-  override def newInstance(): LastFunction = new LastFunction(child, this)
+  override def newInstance(): LastFunction = new LastFunction(child, ignoreNulls, this)
 }
 
-case class LastFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
-  def this() = this(null, null) // Required for serialization.
+object Last {
+  def apply(child: Expression): Last = Last(child, ignoreNulls = false)
+
+  def apply(child: Expression, ignoreNulls: Boolean): Last =
+    Last(child, Literal.create(ignoreNulls, BooleanType))
+}
+
+case class LastFunction(
+    expr: Expression,
+    ignoreNulls: Boolean,
+    base: AggregateExpression1)
+  extends AggregateFunction1 {
+
+  def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization.
 
   var result: Any = null
 
   override def update(input: InternalRow): Unit = {
     val value = expr.eval(input)
-    // We ignore null values.
-    if (value != null) {
+    if (!ignoreNulls || (ignoreNulls && value != null)) {
       result = value
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index c3d2246297..8b9247adea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.expressions
 
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.types.BooleanType
 import org.apache.spark.sql.{Column, catalyst}
 import org.apache.spark.sql.catalyst.expressions._
 
@@ -149,13 +150,17 @@ class WindowSpec private[sql](
       case Count(child) => WindowExpression(
         UnresolvedWindowFunction("count", child :: Nil),
         WindowSpecDefinition(partitionSpec, orderSpec, frame))
-      case First(child) => WindowExpression(
+      case First(child, ignoreNulls) => WindowExpression(
         // TODO this is a hack for Hive UDAF first_value
-        UnresolvedWindowFunction("first_value", child :: Nil),
+        UnresolvedWindowFunction(
+          "first_value",
+          child :: ignoreNulls :: Nil),
         WindowSpecDefinition(partitionSpec, orderSpec, frame))
-      case Last(child) => WindowExpression(
+      case Last(child, ignoreNulls) => WindowExpression(
         // TODO this is a hack for Hive UDAF last_value
-        UnresolvedWindowFunction("last_value", child :: Nil),
+        UnresolvedWindowFunction(
+          "last_value",
+          child :: ignoreNulls :: Nil),
         WindowSpecDefinition(partitionSpec, orderSpec, frame))
       case Min(child) => WindowExpression(
         UnresolvedWindowFunction("min", child :: Nil),
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index c9e1bb1995..f38a3f63c3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -323,6 +323,44 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
       Row(11.125) :: Nil)
   }
 
+  test("first_value and last_value") {
+    // Force the sort and the aggregation to use a single partition to make the result
+    // deterministic.
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+      checkAnswer(
+        sqlContext.sql(
+          """
+            |SELECT
+            |  first_valUE(key),
+            |  lasT_value(key),
+            |  firSt(key),
+            |  lASt(key),
+            |  first_valUE(key, true),
+            |  lasT_value(key, true),
+            |  firSt(key, true),
+            |  lASt(key, true)
+            |FROM (SELECT key FROM agg1 ORDER BY key) tmp
+          """.stripMargin),
+        Row(null, 3, null, 3, 1, 3, 1, 3) :: Nil)
+
+      checkAnswer(
+        sqlContext.sql(
+          """
+            |SELECT
+            |  first_valUE(key),
+            |  lasT_value(key),
+            |  firSt(key),
+            |  lASt(key),
+            |  first_valUE(key, true),
+            |  lasT_value(key, true),
+            |  firSt(key, true),
+            |  lASt(key, true)
+            |FROM (SELECT key FROM agg1 ORDER BY key DESC) tmp
+          """.stripMargin),
+        Row(3, null, 3, null, 3, 1, 3, 1) :: Nil)
+    }
+  }
+
   test("udaf") {
     checkAnswer(
       sqlContext.sql(
-- 
GitLab
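
For reference, the WindowSpec change above forwards the ignoreNulls flag when `First`/`Last`
appear as window functions, so the RESPECT NULLS default applies there as well. A sketch of
such a query through the DataFrame API, assuming a hypothetical DataFrame `df` with columns
`key` and `value`:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    // first(...) over a window is rewritten (see WindowSpec.scala above) into
    // UnresolvedWindowFunction("first_value", child :: ignoreNulls :: Nil).
    val w = Window.partitionBy("value").orderBy("key")
    df.select(col("key"), col("value"), first(col("key")).over(w).as("first_key"))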