Skip to content
Snippets Groups Projects
Commit 67e23b39 authored by Davies Liu's avatar Davies Liu Committed by Michael Armbrust
Browse files

[SPARK-10429] [SQL] make mutableProjection atomic

Right now, SQL's mutable projection updates every value of the mutable project after it evaluates the corresponding expression. This makes the behavior of MutableProjection confusing and complicate the implementation of common aggregate functions like stddev because developers need to be aware that when evaluating {{i+1}}th expression of a mutable projection, {{i}}th slot of the mutable row has already been updated.

This PR make the MutableProjection atomic, by generating all the results of expressions first, then copy them into mutableRow.

Had run a mircro-benchmark, there is no notable performance difference between using class members and local variables.

cc yhuai

Author: Davies Liu <davies@databricks.com>

Closes #9422 from davies/atomic_mutable and squashes the following commits:

bbc1758 [Davies Liu] support wide table
8a0ae14 [Davies Liu] fix bug
bec07da [Davies Liu] refactor
2891628 [Davies Liu] make mutableProjection atomic
parent d728d5c9
No related branches found
No related tags found
No related merge requests found
...@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions ...@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.types.{DataType, Decimal, StructType, _} import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/** /**
* A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
...@@ -62,6 +61,8 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu ...@@ -62,6 +61,8 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema))) this(expressions.map(BindReferences.bindReference(_, inputSchema)))
private[this] val buffer = new Array[Any](expressions.size)
expressions.foreach(_.foreach { expressions.foreach(_.foreach {
case n: Nondeterministic => n.setInitialValues() case n: Nondeterministic => n.setInitialValues()
case _ => case _ =>
...@@ -79,7 +80,13 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu ...@@ -79,7 +80,13 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
override def apply(input: InternalRow): InternalRow = { override def apply(input: InternalRow): InternalRow = {
var i = 0 var i = 0
while (i < exprArray.length) { while (i < exprArray.length) {
mutableRow(i) = exprArray(i).eval(input) // Store the result into buffer first, to make the projection atomic (needed by aggregation)
buffer(i) = exprArray(i).eval(input)
i += 1
}
i = 0
while (i < exprArray.length) {
mutableRow(i) = buffer(i)
i += 1 i += 1
} }
mutableRow mutableRow
......
...@@ -57,37 +57,37 @@ case class Average(child: Expression) extends DeclarativeAggregate { ...@@ -57,37 +57,37 @@ case class Average(child: Expression) extends DeclarativeAggregate {
case _ => DoubleType case _ => DoubleType
} }
private val currentSum = AttributeReference("currentSum", sumDataType)() private val sum = AttributeReference("sum", sumDataType)()
private val currentCount = AttributeReference("currentCount", LongType)() private val count = AttributeReference("count", LongType)()
override val aggBufferAttributes = currentSum :: currentCount :: Nil override val aggBufferAttributes = sum :: count :: Nil
override val initialValues = Seq( override val initialValues = Seq(
/* currentSum = */ Cast(Literal(0), sumDataType), /* sum = */ Cast(Literal(0), sumDataType),
/* currentCount = */ Literal(0L) /* count = */ Literal(0L)
) )
override val updateExpressions = Seq( override val updateExpressions = Seq(
/* currentSum = */ /* sum = */
Add( Add(
currentSum, sum,
Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
/* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) /* count = */ If(IsNull(child), count, count + 1L)
) )
override val mergeExpressions = Seq( override val mergeExpressions = Seq(
/* currentSum = */ currentSum.left + currentSum.right, /* sum = */ sum.left + sum.right,
/* currentCount = */ currentCount.left + currentCount.right /* count = */ count.left + count.right
) )
// If all input are nulls, currentCount will be 0 and we will get null after the division. // If all input are nulls, count will be 0 and we will get null after the division.
override val evaluateExpression = child.dataType match { override val evaluateExpression = child.dataType match {
case DecimalType.Fixed(p, s) => case DecimalType.Fixed(p, s) =>
// increase the precision and scale to prevent precision loss // increase the precision and scale to prevent precision loss
val dt = DecimalType.bounded(p + 14, s + 4) val dt = DecimalType.bounded(p + 14, s + 4)
Cast(Cast(currentSum, dt) / Cast(currentCount, dt), resultType) Cast(Cast(sum, dt) / Cast(count, dt), resultType)
case _ => case _ =>
Cast(currentSum, resultType) / Cast(currentCount, resultType) Cast(sum, resultType) / Cast(count, resultType)
} }
} }
...@@ -102,23 +102,23 @@ case class Count(child: Expression) extends DeclarativeAggregate { ...@@ -102,23 +102,23 @@ case class Count(child: Expression) extends DeclarativeAggregate {
// Expected input data type. // Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
private val currentCount = AttributeReference("currentCount", LongType)() private val count = AttributeReference("count", LongType)()
override val aggBufferAttributes = currentCount :: Nil override val aggBufferAttributes = count :: Nil
override val initialValues = Seq( override val initialValues = Seq(
/* currentCount = */ Literal(0L) /* count = */ Literal(0L)
) )
override val updateExpressions = Seq( override val updateExpressions = Seq(
/* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) /* count = */ If(IsNull(child), count, count + 1L)
) )
override val mergeExpressions = Seq( override val mergeExpressions = Seq(
/* currentCount = */ currentCount.left + currentCount.right /* count = */ count.left + count.right
) )
override val evaluateExpression = Cast(currentCount, LongType) override val evaluateExpression = Cast(count, LongType)
} }
/** /**
...@@ -372,101 +372,77 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { ...@@ -372,101 +372,77 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
private val resultType = DoubleType private val resultType = DoubleType
private val preCount = AttributeReference("preCount", resultType)() private val count = AttributeReference("count", resultType)()
private val currentCount = AttributeReference("currentCount", resultType)() private val avg = AttributeReference("avg", resultType)()
private val preAvg = AttributeReference("preAvg", resultType)() private val mk = AttributeReference("mk", resultType)()
private val currentAvg = AttributeReference("currentAvg", resultType)()
private val currentMk = AttributeReference("currentMk", resultType)()
override val aggBufferAttributes = preCount :: currentCount :: preAvg :: override val aggBufferAttributes = count :: avg :: mk :: Nil
currentAvg :: currentMk :: Nil
override val initialValues = Seq( override val initialValues = Seq(
/* preCount = */ Cast(Literal(0), resultType), /* count = */ Cast(Literal(0), resultType),
/* currentCount = */ Cast(Literal(0), resultType), /* avg = */ Cast(Literal(0), resultType),
/* preAvg = */ Cast(Literal(0), resultType), /* mk = */ Cast(Literal(0), resultType)
/* currentAvg = */ Cast(Literal(0), resultType),
/* currentMk = */ Cast(Literal(0), resultType)
) )
override val updateExpressions = { override val updateExpressions = {
val value = Cast(child, resultType)
val newCount = count + Cast(Literal(1), resultType)
// update average // update average
// avg = avg + (value - avg)/count // avg = avg + (value - avg)/count
def avgAdd: Expression = { val newAvg = avg + (value - avg) / newCount
currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount)
}
// update sum of square of difference from mean // update sum of square of difference from mean
// Mk = Mk + (value - preAvg) * (value - updatedAvg) // Mk = Mk + (value - preAvg) * (value - updatedAvg)
def mkAdd: Expression = { val newMk = mk + (value - avg) * (value - newAvg)
val delta1 = Cast(child, resultType) - preAvg
val delta2 = Cast(child, resultType) - currentAvg
currentMk + (delta1 * delta2)
}
Seq( Seq(
/* preCount = */ If(IsNull(child), preCount, currentCount), /* count = */ If(IsNull(child), count, newCount),
/* currentCount = */ If(IsNull(child), currentCount, /* avg = */ If(IsNull(child), avg, newAvg),
Add(currentCount, Cast(Literal(1), resultType))), /* mk = */ If(IsNull(child), mk, newMk)
/* preAvg = */ If(IsNull(child), preAvg, currentAvg),
/* currentAvg = */ If(IsNull(child), currentAvg, avgAdd),
/* currentMk = */ If(IsNull(child), currentMk, mkAdd)
) )
} }
override val mergeExpressions = { override val mergeExpressions = {
// count merge // count merge
def countMerge: Expression = { val newCount = count.left + count.right
currentCount.left + currentCount.right
}
// average merge // average merge
def avgMerge: Expression = { val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount
((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) /
(preCount + currentCount.right)
}
// update sum of square differences // update sum of square differences
def mkMerge: Expression = { val newMk = {
val avgDelta = currentAvg.right - preAvg val avgDelta = avg.right - avg.left
val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) / val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount
(preCount + currentCount.right) mk.left + mk.right + mkDelta
currentMk.left + currentMk.right + mkDelta
} }
Seq( Seq(
/* preCount = */ If(IsNull(currentCount.left), /* count = */ If(IsNull(count.left), count.right,
Cast(Literal(0), resultType), currentCount.left), If(IsNull(count.right), count.left, newCount)),
/* currentCount = */ If(IsNull(currentCount.left), currentCount.right, /* avg = */ If(IsNull(avg.left), avg.right,
If(IsNull(currentCount.right), currentCount.left, countMerge)), If(IsNull(avg.right), avg.left, newAvg)),
/* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left), /* mk = */ If(IsNull(mk.left), mk.right,
/* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right, If(IsNull(mk.right), mk.left, newMk))
If(IsNull(currentAvg.right), currentAvg.left, avgMerge)),
/* currentMk = */ If(IsNull(currentMk.left), currentMk.right,
If(IsNull(currentMk.right), currentMk.left, mkMerge))
) )
} }
override val evaluateExpression = { override val evaluateExpression = {
// when currentCount == 0, return null // when count == 0, return null
// when currentCount == 1, return 0 // when count == 1, return 0
// when currentCount >1 // when count >1
// stddev_samp = sqrt (currentMk/(currentCount -1)) // stddev_samp = sqrt (mk/(count -1))
// stddev_pop = sqrt (currentMk/currentCount) // stddev_pop = sqrt (mk/count)
val varCol = { val varCol =
if (isSample) { if (isSample) {
currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType) mk / Cast((count - Cast(Literal(1), resultType)), resultType)
} } else {
else { mk / count
currentMk / currentCount
} }
}
If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null), resultType),
If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), If(EqualTo(count, Cast(Literal(1), resultType)), Cast(Literal(0), resultType),
Cast(Sqrt(varCol), resultType))) Cast(Sqrt(varCol), resultType)))
} }
} }
...@@ -499,30 +475,30 @@ case class Sum(child: Expression) extends DeclarativeAggregate { ...@@ -499,30 +475,30 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
private val sumDataType = resultType private val sumDataType = resultType
private val currentSum = AttributeReference("currentSum", sumDataType)() private val sum = AttributeReference("sum", sumDataType)()
private val zero = Cast(Literal(0), sumDataType) private val zero = Cast(Literal(0), sumDataType)
override val aggBufferAttributes = currentSum :: Nil override val aggBufferAttributes = sum :: Nil
override val initialValues = Seq( override val initialValues = Seq(
/* currentSum = */ Literal.create(null, sumDataType) /* sum = */ Literal.create(null, sumDataType)
) )
override val updateExpressions = Seq( override val updateExpressions = Seq(
/* currentSum = */ /* sum = */
Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, sumDataType)), currentSum)) Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
) )
override val mergeExpressions = { override val mergeExpressions = {
val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, sumDataType)) val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType))
Seq( Seq(
/* currentSum = */ /* sum = */
Coalesce(Seq(add, currentSum.left)) Coalesce(Seq(add, sum.left))
) )
} }
override val evaluateExpression = Cast(currentSum, resultType) override val evaluateExpression = Cast(sum, resultType)
} }
/** /**
......
...@@ -44,28 +44,42 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu ...@@ -44,28 +44,42 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
case (NoOp, _) => "" case (NoOp, _) => ""
case (e, i) => case (e, i) =>
val evaluationCode = e.gen(ctx) val evaluationCode = e.gen(ctx)
val isNull = s"isNull_$i"
val value = s"value_$i"
ctx.addMutableState("boolean", isNull, s"this.$isNull = true;")
ctx.addMutableState(ctx.javaType(e.dataType), value,
s"this.$value = ${ctx.defaultValue(e.dataType)};")
s"""
${evaluationCode.code}
this.$isNull = ${evaluationCode.isNull};
this.$value = ${evaluationCode.value};
"""
}
val updates = expressions.zipWithIndex.map {
case (NoOp, _) => ""
case (e, i) =>
if (e.dataType.isInstanceOf[DecimalType]) { if (e.dataType.isInstanceOf[DecimalType]) {
// Can't call setNullAt on DecimalType, because we need to keep the offset // Can't call setNullAt on DecimalType, because we need to keep the offset
s""" s"""
${evaluationCode.code} if (this.isNull_$i) {
if (${evaluationCode.isNull}) {
${ctx.setColumn("mutableRow", e.dataType, i, null)}; ${ctx.setColumn("mutableRow", e.dataType, i, null)};
} else { } else {
${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.value)}; ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
} }
""" """
} else { } else {
s""" s"""
${evaluationCode.code} if (this.isNull_$i) {
if (${evaluationCode.isNull}) {
mutableRow.setNullAt($i); mutableRow.setNullAt($i);
} else { } else {
${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.value)}; ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
} }
""" """
} }
} }
val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes)
val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates)
val code = s""" val code = s"""
public Object generate($exprType[] expr) { public Object generate($exprType[] expr) {
...@@ -98,6 +112,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu ...@@ -98,6 +112,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
public Object apply(Object _i) { public Object apply(Object _i) {
InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i;
$allProjections $allProjections
// copy all the results into MutableRow
$allUpdates
return mutableRow; return mutableRow;
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment