Commit 5fd54b99 authored by Herman van Hovell, committed by Yin Huai

[SPARK-17758][SQL] Last returns wrong result in case of empty partition

## What changes were proposed in this pull request?
The result of the `Last` function can be wrong when the last partition processed is empty. It can return `null` instead of the expected value. For example, this can happen when we process partitions in the following order:
```
- Partition 1 [Row1, Row2]
- Partition 2 [Row3]
- Partition 3 []
```
In this case the `Last` function currently returns `null` instead of the value of `Row3`.
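
To make the failure concrete, here is a toy Scala sketch of the old merge rule, which always keeps the right-hand buffer. This is not Spark code; `OldLastMergeDemo`, the `Option[Int]` buffer, and the helper names are made up purely for illustration:
```scala
// Toy model of the old merge rule for Last (ignoreNulls = false): each partition
// buffer keeps the last value it saw, and merging always keeps the right-hand buffer.
object OldLastMergeDemo extends App {
  type Buffer = Option[Int]                     // None models an unset buffer (null)

  def update(rows: Seq[Int]): Buffer = rows.lastOption

  // Old rule: last = last.right, i.e. the right-hand buffer wins unconditionally.
  def merge(left: Buffer, right: Buffer): Buffer = right

  val p1 = update(Seq(1, 2))   // Partition 1 [Row1, Row2]
  val p2 = update(Seq(3))      // Partition 2 [Row3]
  val p3 = update(Seq.empty)   // Partition 3 [] -> buffer stays unset

  println(Seq(p1, p2, p3).reduce(merge))  // None: the empty partition wipes out Row3
}
```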

This PR fixes this by adding a `valueSet` flag to the `Last` function.
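
Under the same toy model, the fix amounts to carrying a `valueSet` flag in the aggregation buffer and only preferring the right-hand buffer when that flag is set. The real Catalyst expressions are in the diff below; this is only an illustrative sketch with made-up names:
```scala
// Toy model of the fixed merge: the buffer is (last, valueSet) and the right-hand
// side only wins when its valueSet flag says it has actually seen a row.
object FixedLastMergeDemo extends App {
  type Buffer = (Option[Int], Boolean)          // (last value, valueSet)

  def update(rows: Seq[Int]): Buffer =
    if (rows.isEmpty) (None, false) else (Some(rows.last), true)

  // New rule: last     = if (valueSet.right) last.right else last.left
  //           valueSet = valueSet.right || valueSet.left
  def merge(left: Buffer, right: Buffer): Buffer =
    (if (right._2) right._1 else left._1, right._2 || left._2)

  val partitions = Seq(update(Seq(1, 2)), update(Seq(3)), update(Seq.empty))
  println(partitions.reduce(merge))  // (Some(3),true): the trailing empty partition no longer matters
}
```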

## How was this patch tested?
Previously we only had end-to-end tests for `DeclarativeAggregateFunction`s. I have added an evaluator for these functions so we can test them directly in Catalyst. I have also added a `LastTestSuite` to test the `Last` aggregate function.
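
The evaluator drives a declarative aggregate through its initialize/update/merge/evaluate lifecycle by hand. A rough sketch of how a test uses it, mirroring `LastTestSuite` further down (assumed to sit in the same package as the test so `Last` and `DeclarativeAggregateEvaluator` resolve, and written as a test body rather than a standalone program):
```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
import org.apache.spark.sql.types.IntegerType

// Drive Last(ignoreNulls = false) by hand: one partition with rows, one empty partition.
val input = AttributeReference("input", IntegerType, nullable = true)()
val evaluator = DeclarativeAggregateEvaluator(Last(input, Literal(false)), Seq(input))

val p1 = evaluator.update(InternalRow(1), InternalRow(-99)) // partition with rows
val p0 = evaluator.initialize()                             // empty partition
val merged = evaluator.merge(p1, p0)                        // empty partition merged last
evaluator.eval(merged)                                      // InternalRow(-99), not InternalRow(null)
```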

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #15348 from hvanhovell/SPARK-17758.
parent 221b418b
@@ -55,34 +55,35 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat
   private lazy val last = AttributeReference("last", child.dataType)()
 
-  override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: Nil
+  private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
+
+  override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: valueSet :: Nil
 
   override lazy val initialValues: Seq[Literal] = Seq(
-    /* last = */ Literal.create(null, child.dataType)
+    /* last = */ Literal.create(null, child.dataType),
+    /* valueSet = */ Literal.create(false, BooleanType)
   )
 
   override lazy val updateExpressions: Seq[Expression] = {
     if (ignoreNulls) {
       Seq(
-        /* last = */ If(IsNull(child), last, child)
+        /* last = */ If(IsNull(child), last, child),
+        /* valueSet = */ Or(valueSet, IsNotNull(child))
       )
     } else {
       Seq(
-        /* last = */ child
+        /* last = */ child,
+        /* valueSet = */ Literal.create(true, BooleanType)
      )
     }
   }
 
   override lazy val mergeExpressions: Seq[Expression] = {
-    if (ignoreNulls) {
-      Seq(
-        /* last = */ If(IsNull(last.right), last.left, last.right)
-      )
-    } else {
-      Seq(
-        /* last = */ last.right
-      )
-    }
+    // Prefer the right hand expression if it has been set.
+    Seq(
+      /* last = */ If(valueSet.right, last.right, last.left),
+      /* valueSet = */ Or(valueSet.right, valueSet.left)
+    )
   }
 
   override lazy val evaluateExpression: AttributeReference = last
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection

/**
 * Evaluator for a [[DeclarativeAggregate]].
 */
case class DeclarativeAggregateEvaluator(function: DeclarativeAggregate, input: Seq[Attribute]) {

  lazy val initializer = GenerateSafeProjection.generate(function.initialValues)

  lazy val updater = GenerateSafeProjection.generate(
    function.updateExpressions,
    function.aggBufferAttributes ++ input)

  lazy val merger = GenerateSafeProjection.generate(
    function.mergeExpressions,
    function.aggBufferAttributes ++ function.inputAggBufferAttributes)

  lazy val evaluator = GenerateSafeProjection.generate(
    function.evaluateExpression :: Nil,
    function.aggBufferAttributes)

  def initialize(): InternalRow = initializer.apply(InternalRow.empty).copy()

  def update(values: InternalRow*): InternalRow = {
    val joiner = new JoinedRow
    val buffer = values.foldLeft(initialize()) { (buffer, input) =>
      updater(joiner(buffer, input))
    }
    buffer.copy()
  }

  def merge(buffers: InternalRow*): InternalRow = {
    val joiner = new JoinedRow
    val buffer = buffers.foldLeft(initialize()) { (left, right) =>
      merger(joiner(left, right))
    }
    buffer.copy()
  }

  def eval(buffer: InternalRow): InternalRow = evaluator(buffer).copy()
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
import org.apache.spark.sql.types.IntegerType

class LastTestSuite extends SparkFunSuite {
  val input = AttributeReference("input", IntegerType, nullable = true)()
  val evaluator = DeclarativeAggregateEvaluator(Last(input, Literal(false)), Seq(input))
  val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(Last(input, Literal(true)), Seq(input))

  test("empty buffer") {
    assert(evaluator.initialize() === InternalRow(null, false))
  }

  test("update") {
    val result = evaluator.update(
      InternalRow(1),
      InternalRow(9),
      InternalRow(-1))
    assert(result === InternalRow(-1, true))
  }

  test("update - ignore nulls") {
    val result1 = evaluatorIgnoreNulls.update(
      InternalRow(null),
      InternalRow(9),
      InternalRow(null))
    assert(result1 === InternalRow(9, true))

    val result2 = evaluatorIgnoreNulls.update(
      InternalRow(null),
      InternalRow(null))
    assert(result2 === InternalRow(null, false))
  }

  test("merge") {
    // Empty merge
    val p0 = evaluator.initialize()
    assert(evaluator.merge(p0) === InternalRow(null, false))

    // Single merge
    val p1 = evaluator.update(InternalRow(1), InternalRow(-99))
    assert(evaluator.merge(p1) === p1)

    // Multiple merges.
    val p2 = evaluator.update(InternalRow(2), InternalRow(10))
    assert(evaluator.merge(p1, p2) === p2)

    // Empty partitions (p0 is empty)
    assert(evaluator.merge(p1, p0, p2) === p2)
    assert(evaluator.merge(p2, p1, p0) === p1)
  }

  test("merge - ignore nulls") {
    // Multi merges
    val p1 = evaluatorIgnoreNulls.update(InternalRow(1), InternalRow(null))
    val p2 = evaluatorIgnoreNulls.update(InternalRow(null), InternalRow(null))
    assert(evaluatorIgnoreNulls.merge(p1, p2) === p1)
  }

  test("eval") {
    // Null Eval
    assert(evaluator.eval(InternalRow(null, true)) === InternalRow(null))
    assert(evaluator.eval(InternalRow(null, false)) === InternalRow(null))

    // Empty Eval
    val p0 = evaluator.initialize()
    assert(evaluator.eval(p0) === InternalRow(null))

    // Update - Eval
    val p1 = evaluator.update(InternalRow(1), InternalRow(-99))
    assert(evaluator.eval(p1) === InternalRow(-99))

    // Update - Merge - Eval
    val p2 = evaluator.update(InternalRow(2), InternalRow(10))
    val m1 = evaluator.merge(p1, p0, p2)
    assert(evaluator.eval(m1) === InternalRow(10))

    // Update - Merge - Eval (empty partition at the end)
    val m2 = evaluator.merge(p2, p1, p0)
    assert(evaluator.eval(m2) === InternalRow(-99))
  }

  test("eval - ignore nulls") {
    // Update - Merge - Eval
    val p1 = evaluatorIgnoreNulls.update(InternalRow(1), InternalRow(null))
    val p2 = evaluatorIgnoreNulls.update(InternalRow(null), InternalRow(null))
    val m1 = evaluatorIgnoreNulls.merge(p1, p2)
    assert(evaluatorIgnoreNulls.eval(m1) === InternalRow(1))
  }
}