Skip to content
Snippets Groups Projects
Commit 3b0babad authored by Takuya UESHIN's avatar Takuya UESHIN Committed by Reynold Xin
Browse files

[SPARK-1915] [SQL] AverageFunction should not count if the evaluated value is null.

Average values differ depending on whether the calculation is done partially or not,
because `AverageFunction` (used in the non-partial calculation) increments its count even if the evaluated value is null.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #862 from ueshin/issues/SPARK-1915 and squashes the following commits:

b1ff3c0 [Takuya UESHIN] Modify AverageFunction not to count if the evaluated value is null.
parent d1375a2b
No related branches found
No related tags found
No related merge requests found
......@@ -281,14 +281,17 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
private val sum = MutableLiteral(zero.eval(EmptyRow))
private val sumAsDouble = Cast(sum, DoubleType)
private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
private def addFunction(value: Any) = Add(sum, Literal(value))
override def eval(input: Row): Any =
sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble
override def update(input: Row): Unit = {
count += 1
sum.update(addFunction, input)
val evaluatedExpr = expr.eval(input)
if (evaluatedExpr != null) {
count += 1
sum.update(addFunction(evaluatedExpr), input)
}
}
}
......
......@@ -115,6 +115,16 @@ class DslQuerySuite extends QueryTest {
2.0)
}
// Regression test for SPARK-1915: Average('b) over testData3 is expected to be 2.0.
// NOTE(review): presumably testData3 has at least one null in column b, so the
// expected values only hold when null rows are excluded from the count -- confirm
// against the testData3 definition.
test("null average") {
// Average as the only aggregate.
checkAnswer(
testData3.groupBy()(Average('b)),
2.0)
// Average combined with CountDistinct in one aggregation; the average must still be 2.0.
// NOTE(review): presumably this forces a different (non-partial) aggregation path -- confirm.
checkAnswer(
testData3.groupBy()(Average('b), CountDistinct('b :: Nil)),
(2.0, 1) :: Nil)
}
// Checks that count() on testData2 equals the count obtained after mapping every row to 1.
test("count") {
assert(testData2.count() === testData2.map(_ => 1).count())
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment