diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index afe52e6a667eba2dfd5651cd1699b9635f3d3d86..a6fe730f6dad49e266d891c3cee68d3381c2fd1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.types.{DataType, Decimal, StructType, _} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.sql.types.{DataType, StructType} /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. @@ -62,6 +61,8 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) + private[this] val buffer = new Array[Any](expressions.size) + expressions.foreach(_.foreach { case n: Nondeterministic => n.setInitialValues() case _ => @@ -79,7 +80,13 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu override def apply(input: InternalRow): InternalRow = { var i = 0 while (i < exprArray.length) { - mutableRow(i) = exprArray(i).eval(input) + // Store the result into buffer first, to make the projection atomic (needed by aggregation) + buffer(i) = exprArray(i).eval(input) + i += 1 + } + i = 0 + while (i < exprArray.length) { + mutableRow(i) = buffer(i) i += 1 } mutableRow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 5d2eb7b017ab9bc1ada461a1a869875b1061c8be..f2c3eca0951155ca87f03e576f928cade91b57b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -57,37 +57,37 @@ case class Average(child: Expression) extends DeclarativeAggregate { case _ => DoubleType } - private val currentSum = AttributeReference("currentSum", sumDataType)() - private val currentCount = AttributeReference("currentCount", LongType)() + private val sum = AttributeReference("sum", sumDataType)() + private val count = AttributeReference("count", LongType)() - override val aggBufferAttributes = currentSum :: currentCount :: Nil + override val aggBufferAttributes = sum :: count :: Nil override val initialValues = Seq( - /* currentSum = */ Cast(Literal(0), sumDataType), - /* currentCount = */ Literal(0L) + /* sum = */ Cast(Literal(0), sumDataType), + /* count = */ Literal(0L) ) override val updateExpressions = Seq( - /* currentSum = */ + /* sum = */ Add( - currentSum, + sum, Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), - /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) + /* count = */ If(IsNull(child), count, count + 1L) ) override val mergeExpressions = Seq( - /* currentSum = */ currentSum.left + currentSum.right, - /* currentCount = */ currentCount.left + currentCount.right + /* sum = */ sum.left + sum.right, + /* count = */ count.left + count.right ) - // If all input are nulls, currentCount will be 0 and we will get null after the division. + // If all input are nulls, count will be 0 and we will get null after the division. override val evaluateExpression = child.dataType match { case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) - Cast(Cast(currentSum, dt) / Cast(currentCount, dt), resultType) + Cast(Cast(sum, dt) / Cast(count, dt), resultType) case _ => - Cast(currentSum, resultType) / Cast(currentCount, resultType) + Cast(sum, resultType) / Cast(count, resultType) } } @@ -102,23 +102,23 @@ case class Count(child: Expression) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val currentCount = AttributeReference("currentCount", LongType)() + private val count = AttributeReference("count", LongType)() - override val aggBufferAttributes = currentCount :: Nil + override val aggBufferAttributes = count :: Nil override val initialValues = Seq( - /* currentCount = */ Literal(0L) + /* count = */ Literal(0L) ) override val updateExpressions = Seq( - /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) + /* count = */ If(IsNull(child), count, count + 1L) ) override val mergeExpressions = Seq( - /* currentCount = */ currentCount.left + currentCount.right + /* count = */ count.left + count.right ) - override val evaluateExpression = Cast(currentCount, LongType) + override val evaluateExpression = Cast(count, LongType) } /** @@ -372,101 +372,77 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { private val resultType = DoubleType - private val preCount = AttributeReference("preCount", resultType)() - private val currentCount = AttributeReference("currentCount", resultType)() - private val preAvg = AttributeReference("preAvg", resultType)() - private val currentAvg = AttributeReference("currentAvg", resultType)() - private val currentMk = AttributeReference("currentMk", resultType)() + private val count = AttributeReference("count", resultType)() + private val avg = AttributeReference("avg", resultType)() + private val mk = AttributeReference("mk", resultType)() - override val aggBufferAttributes = preCount :: currentCount :: preAvg :: - currentAvg :: currentMk :: Nil + override val aggBufferAttributes = count :: avg :: mk :: Nil override val initialValues = Seq( - /* preCount = */ Cast(Literal(0), resultType), - /* currentCount = */ Cast(Literal(0), resultType), - /* preAvg = */ Cast(Literal(0), resultType), - /* currentAvg = */ Cast(Literal(0), resultType), - /* currentMk = */ Cast(Literal(0), resultType) + /* count = */ Cast(Literal(0), resultType), + /* avg = */ Cast(Literal(0), resultType), + /* mk = */ Cast(Literal(0), resultType) ) override val updateExpressions = { + val value = Cast(child, resultType) + val newCount = count + Cast(Literal(1), resultType) // update average // avg = avg + (value - avg)/count - def avgAdd: Expression = { - currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount) - } + val newAvg = avg + (value - avg) / newCount // update sum of square of difference from mean // Mk = Mk + (value - preAvg) * (value - updatedAvg) - def mkAdd: Expression = { - val delta1 = Cast(child, resultType) - preAvg - val delta2 = Cast(child, resultType) - currentAvg - currentMk + (delta1 * delta2) - } + val newMk = mk + (value - avg) * (value - newAvg) Seq( - /* preCount = */ If(IsNull(child), preCount, currentCount), - /* currentCount = */ If(IsNull(child), currentCount, - Add(currentCount, Cast(Literal(1), resultType))), - /* preAvg = */ If(IsNull(child), preAvg, currentAvg), - /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd), - /* currentMk = */ If(IsNull(child), currentMk, mkAdd) + /* count = */ If(IsNull(child), count, newCount), + /* avg = */ If(IsNull(child), avg, newAvg), + /* mk = */ If(IsNull(child), mk, newMk) ) } override val mergeExpressions = { // count merge - def countMerge: Expression = { - currentCount.left + currentCount.right - } + val newCount = count.left + count.right // average merge - def avgMerge: Expression = { - ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) / - (preCount + currentCount.right) - } + val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount // update sum of square differences - def mkMerge: Expression = { - val avgDelta = currentAvg.right - preAvg - val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) / - (preCount + currentCount.right) - - currentMk.left + currentMk.right + mkDelta + val newMk = { + val avgDelta = avg.right - avg.left + val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount + mk.left + mk.right + mkDelta } Seq( - /* preCount = */ If(IsNull(currentCount.left), - Cast(Literal(0), resultType), currentCount.left), - /* currentCount = */ If(IsNull(currentCount.left), currentCount.right, - If(IsNull(currentCount.right), currentCount.left, countMerge)), - /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left), - /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right, - If(IsNull(currentAvg.right), currentAvg.left, avgMerge)), - /* currentMk = */ If(IsNull(currentMk.left), currentMk.right, - If(IsNull(currentMk.right), currentMk.left, mkMerge)) + /* count = */ If(IsNull(count.left), count.right, + If(IsNull(count.right), count.left, newCount)), + /* avg = */ If(IsNull(avg.left), avg.right, + If(IsNull(avg.right), avg.left, newAvg)), + /* mk = */ If(IsNull(mk.left), mk.right, + If(IsNull(mk.right), mk.left, newMk)) ) } override val evaluateExpression = { - // when currentCount == 0, return null - // when currentCount == 1, return 0 - // when currentCount >1 - // stddev_samp = sqrt (currentMk/(currentCount -1)) - // stddev_pop = sqrt (currentMk/currentCount) - val varCol = { + // when count == 0, return null + // when count == 1, return 0 + // when count >1 + // stddev_samp = sqrt (mk/(count -1)) + // stddev_pop = sqrt (mk/count) + val varCol = if (isSample) { - currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType) - } - else { - currentMk / currentCount + mk / Cast((count - Cast(Literal(1), resultType)), resultType) + } else { + mk / count } - } - If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), - If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), + If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), + If(EqualTo(count, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), Cast(Sqrt(varCol), resultType))) } } @@ -499,30 +475,30 @@ case class Sum(child: Expression) extends DeclarativeAggregate { private val sumDataType = resultType - private val currentSum = AttributeReference("currentSum", sumDataType)() + private val sum = AttributeReference("sum", sumDataType)() private val zero = Cast(Literal(0), sumDataType) - override val aggBufferAttributes = currentSum :: Nil + override val aggBufferAttributes = sum :: Nil override val initialValues = Seq( - /* currentSum = */ Literal.create(null, sumDataType) + /* sum = */ Literal.create(null, sumDataType) ) override val updateExpressions = Seq( - /* currentSum = */ - Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, sumDataType)), currentSum)) + /* sum = */ + Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) ) override val mergeExpressions = { - val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, sumDataType)) + val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType)) Seq( - /* currentSum = */ - Coalesce(Seq(add, currentSum.left)) + /* sum = */ + Coalesce(Seq(add, sum.left)) ) } - override val evaluateExpression = Cast(currentSum, resultType) + override val evaluateExpression = Cast(sum, resultType) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index e8ee64756d5d0dca4800332b4ca60bef42638cd1..4b66069b5f55aefb05939c8a5b6ff624635555f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -44,28 +44,42 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) + val isNull = s"isNull_$i" + val value = s"value_$i" + ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") + ctx.addMutableState(ctx.javaType(e.dataType), value, + s"this.$value = ${ctx.defaultValue(e.dataType)};") + s""" + ${evaluationCode.code} + this.$isNull = ${evaluationCode.isNull}; + this.$value = ${evaluationCode.value}; + """ + } + val updates = expressions.zipWithIndex.map { + case (NoOp, _) => "" + case (e, i) => if (e.dataType.isInstanceOf[DecimalType]) { // Can't call setNullAt on DecimalType, because we need to keep the offset s""" - ${evaluationCode.code} - if (${evaluationCode.isNull}) { + if (this.isNull_$i) { ${ctx.setColumn("mutableRow", e.dataType, i, null)}; } else { - ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.value)}; + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } """ } else { s""" - ${evaluationCode.code} - if (${evaluationCode.isNull}) { + if (this.isNull_$i) { mutableRow.setNullAt($i); } else { - ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.value)}; + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } """ } } + val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) + val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) val code = s""" public Object generate($exprType[] expr) { @@ -98,6 +112,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu public Object apply(Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allProjections + // copy all the results into MutableRow + $allUpdates return mutableRow; } }